Transcript of YouTube Video: Stanford CS25: V1 I Transformer Circuits, Induction Heads, In-Context Learning

The following is an AI-generated summary and article based on the transcript of the video "Stanford CS25: V1 I Transformer Circuits, Induction Heads, In-Context Learning". Because of the limitations of AI, please verify the correctness of the content before relying on it.

Video Transcript
00:05

thank you all for having me it's

00:06

exciting to be here uh one of my

00:08

favorite things is talking about what is

00:11

going on inside neural networks or at

00:13

least what we what we're trying to

00:14

figure out is going on inside neural

00:15

networks so it's it's always fun to chat

00:17

about that

00:19

um

00:20

oh gosh i have to figure out how to how

00:22

to do things okay can i

00:24

what i i want okay there we go now now

00:27

we are advancing slides that seems wrong

00:30

um so i think interpretability means

00:32

lots of different things to different

00:34

people um it's a very a very broad term

00:36

and and people mean all sorts of

00:38

different things by it

00:39

um and so i wanted to talk just briefly

00:41

about uh the kind of interpretability

00:43

that i i spend my time thinking about um

00:45

which is what i'd call mechanistic

00:47

interpretability so

00:49

um most of my work actually has not been

00:51

on language models or on rnn's or

00:54

transformers but um on understanding

00:56

vision convnets and and trying to

00:58

understand how do the parameters in

01:00

those models actually map to algorithms

01:04

so you can like think of the parameters

01:06

of a neural network as being like a

01:07

compiled computer program and and the

01:10

neurons are kind of like variables or

01:11

registers and somehow there's there

01:14

there are these these complex computer

01:16

programs that are are embedded in those

01:18

weights and we'd like to turn them back

01:19

in to computer programs that that humans

01:21

can understand it's a kind of kind of

01:23

reverse engineering problem

01:25

um

01:26

and so this is this is kind of a

01:28

fun example that we found where there

01:29

was a car neuron and you could actually

01:31

see that um you know that we have the

01:33

car neuron and it's constructed from

01:35

like a wheel neuron

01:37

and it looks for in the case of the

01:39

wheel neuron it's looking for for the

01:40

wheels on the bottom and those are

01:42

positive weights and it doesn't want to

01:43

see them on top so there's negative

01:45

weights there and there's also a window

01:46

neuron it's looking for the windows on

01:48

the top and and not on the bottom and so

01:50

what we're actually seeing there right

01:51

is it's an algorithm it's an algorithm

01:54

that goes and turns um you know it's

01:57

it's just it's you know saying you know

01:58

well cars is has wheels on the bottom

02:00

and windows on the top and chrome in the

02:02

middle um and that's that's actually

02:03

like just the the strongest neurons for

02:05

that and so we're actually seeing a

02:07

meaningful algorithm and that's that's

02:08

not an exception that's that's sort of

02:10

the the general story that if you're

02:12

willing to go and look at neural neural

02:14

network weights and you're willing to

02:15

invest a lot of energy and trying to

02:16

first engineer them there's there's

02:18

meaningful algorithms written in the

02:20

weights waiting for you to find them

02:22

um and there's a bunch of reasons i

02:24

think that's an interesting thing to

02:25

think about one is you know just no one

02:27

knows how to go and do the things that

02:29

neural networks can do like no one knows

02:30

how to write a computer program that can

02:32

accurately classify imagenet let alone

02:33

you know the language modeling tasks

02:35

that we're doing no one knows how to

02:36

like directly write a computer program

02:38

that can do the things that GPT-3 does

02:40

and yet somehow gradient descent is able

02:41

to go and discover a way to do this and

02:43

i want to know what's going on i want to

02:45

know you know what it has discovered

02:48

that it can do in in these systems

02:51

there's another reason why i think this

02:53

is important which is uh is safety so

02:55

you know if we if we want to go and use

02:57

these systems in in places where they

02:59

have big effect on the world and

03:01

i think a question we need to ask

03:02

ourselves is you know what what happens

03:05

when these models have have

03:07

unanticipated failure modes failure

03:08

modes we didn't know to go and test for

03:10

or to look for to check for

03:12

how can we how can we discover those

03:13

things especially if they're they're

03:14

really pathological failure modes so the

03:16

models in some sense deliberately doing

03:17

something that we don't want well the

03:20

only way that i really see that we we

03:21

can do that is if we can get to a point

03:23

where we really understand what's going

03:24

on inside these systems

03:26

um so that's another reason that i'm

03:28

interested in this

03:30

now uh actually doing interpretability on

03:32

language models and transformers it's

03:34

new to me i um before this year i spent

03:36

like eight years working on trying

03:38

reverse engineer convnets uh and

03:40

vision models um and so the ideas in

03:42

this talk um are are new things that

03:45

i've been thinking about with my

03:46

collaborators um and we're still

03:48

probably a month or two out maybe maybe

03:50

longer from publishing them um and this

03:52

is also the first public talk that i've

03:53

given on it so uh you know the things

03:55

i'm going to talk about um they made

03:57

they're i think honestly still a little

03:58

bit confused for me um and definitely

04:00

are going to be confused in my

04:01

articulation of them so if i say things

04:03

that are confusing um you know please

04:05

feel free to ask me questions there

04:06

might be some points for me to go

04:07

quickly because there's a lot of content

04:09

um but definitely at the end i will be

04:10

available for a while to chat about the

04:12

stuff um

04:14

and uh yeah also i apologize um if uh if

04:18

i'm unfamiliar with zoom and make make

04:20

mistakes um but

04:22

uh yeah so um with that said uh let's

04:25

dive in

04:26

um so i wanted to start with a mystery

04:31

um before we go and try to actually dig

04:34

into you know what's going on inside

04:35

these models um i wanted to motivate it

04:38

by a really strange piece of discover of

04:41

behavior that we discovered and and

04:42

wanted to understand

04:45

um

04:46

uh and by the way i should say all this

04:48

work is um uh you know is done with my

04:50

my colleagues at Anthropic and especially

04:52

my colleagues catherine and nelson

04:54

okay so on to the mystery

04:56

um i think probably the the most

04:58

interesting and most exciting thing

05:00

about um about transformers is their

05:03

ability to do in-context learning or

05:06

sometimes people call it meta-learning

05:08

um you know the GPT-3 paper uh goes and

05:10

and describes things as uh you know uh

05:13

language models are few shot learners

05:14

like there's lots of impressive things

05:15

about GPT-3 but they choose to focus on

05:17

that and you know now everyone's talking

05:18

about prompt engineering um and um

05:21

Andrej Karpathy was was joking about how

05:23

you know software 3.0 was designing the

05:25

prompt and so the ability of language

05:28

models of these these large transformers

05:29

to respond to their context and learn

05:32

from their context and change their

05:33

behavior and response to their context

05:35

and you know really seems like probably

05:37

the most surprising and striking and

05:38

remarkable thing about them

05:41

um

05:42

and

05:43

uh some of my my colleagues previously

05:45

published a paper that has a trick in it

05:47

that i i really love which is so we're

05:49

all used to looking at learning curves

05:51

you train your model and you you know as

05:52

your model trains the loss goes down

05:56

sometimes it's a little bit

05:57

discontinuous but it goes down

06:00

another thing that you can do is you can

06:02

go and take a fully trained model and

06:03

you can go and ask you know as we go

06:05

through the context you know as we go we

06:07

predict the first token and then the

06:08

second token and the third token we get

06:10

better at predicting each token because

06:12

we have more information to go and

06:13

predict it on so you know the first the

06:15

first token the the loss should be

06:17

the the entropy of the unigrams and then

06:19

the next token should be the entropy of

06:20

the bigrams and it falls

06:22

but it keeps falling

06:24

and it keeps getting better

06:26

and

06:26

in in some sense that's our that's the

06:28

model's ability to go and predict to go

06:31

and do in-context learning the ability

06:34

to go and predict um you know to be

06:36

better at predicting later tokens than

06:37

you are predicting early tokens that is

06:39

that is in some sense a mathematical

06:40

definition of what it means to be good

06:42

at this magical in-context learning or

06:44

meta-learning that these models can do

06:46

and so that's kind of cool because that

06:48

gives us a a way to go and look at

06:50

whether models are good at in-context

06:52

learning

06:54

yeah if i could just ask the question

06:55

like a clarification question

06:57

please when you say learning there are

06:59

no actual parameters being updated

07:02

that is the remarkable thing about

07:03

in-context learning right so yeah indeed

07:05

we traditionally think about neural

07:06

networks as learning over the course of

07:08

training by going and modifying their

07:10

parameters but somehow models appear to

07:12

also be able to learn in some sense um

07:14

if you give them a couple examples in

07:15

their context they can then go and do

07:17

that later in their context even though

07:18

no parameters changed and so it's it's

07:21

some kind of quite different different

07:22

notion of learning as you're as you're

07:24

gesturing that

07:26

uh

07:26

okay i think that's making more sense so

07:28

i mean could you also just describe in

07:31

context learning in this case as

07:32

conditioning as in like conditioning on

07:34

the first five tokens of a ten token

07:36

sentence

07:37

yeah i think the

07:39

reason that people sometimes think about

07:41

this as in context learning or meta

07:42

learning is that you can do things where

07:45

you like actually take a training set

07:46

and you embed the training set in your

07:48

context like if you just two or three

07:49

examples and then suddenly your model

07:51

can go and do do this task and so you

07:54

can do few-shot learning by embedding

07:55

things in the context yeah the formal

07:59

setup is that you're you're just

08:00

conditioning on on on this context and

08:02

it's just that somehow this this ability

08:05

like this thing like there's there's

08:06

some sense you know for a long time

08:08

people were

08:09

were

08:09

i mean i i guess really the history of

08:11

this is uh

08:13

we started to get good at neural

08:14

networks learning right um and we could

08:17

we could go and train language uh train

08:18

vision models and language models that

08:19

could do all these remarkable things but

08:20

then people started to be like well you

08:22

know these systems are they take so many

08:24

more examples than humans do to go and

08:26

learn how can we go and fix this and we

08:28

had all these ideas about meta-learning

08:29

develop where we wanted to go and and

08:32

train models

08:33

explicitly to be able to learn from a

08:34

few examples and people developed all

08:36

these complicated schemes and then the

08:37

like truly like absurd thing about about

08:39

transformer language models is without

08:41

any effort at all we get this for free

08:44

that you can go and just give them a

08:45

couple examples in their context and

08:47

they can learn in their context to go

08:48

and do new things um i think that was

08:51

like like that was in some sense the

08:52

like most striking thing about the GPT-3

08:54

paper

08:55

um

08:56

and so uh this this yeah this ability to

08:59

go and have the just conditioning on a

09:01

context go and give you you know new

09:03

abilities for free and and the ability

09:05

to generalize to new things is in some

09:07

sense the the most yeah and to me the

09:09

most striking and shocking thing about

09:11

about transformer language models

09:14

that makes sense i mean i guess

09:17

from my perspective

09:18

i'm trying to square like

09:21

the notion of learning in this case

09:23

with you know if you or i were given a

09:25

prompt of like one plus one equals two

09:27

two plus three equals five

09:29

as the sort of few shot set up and then

09:33

somebody else put you know like five

09:34

plus three equals and we had to fill it

09:36

out in that case i wouldn't say that

09:38

we've learned arithmetic because we

09:40

already sort of knew it but rather we're

09:42

just sort of conditioning on the prompt

09:44

to know what it is that we should then

09:46

generate right

09:47

uh but it seems to me like that's

09:50

yeah i think that's on the spectrum

09:52

though because you can you can also go

09:53

and give like completely nonsensical

09:55

problems where the model would never

09:57

have seen um see like mimic this

10:00

function and give a couple examples of

10:01

the function and the model's never seen

10:02

it before and i can go and do that later

10:04

in the context um and i think i think

10:06

what you did learn um in a lot of these

10:08

cases you might not have you might have

10:10

um

10:11

you might not have learned arithmetic

10:12

like you might have had some innate

10:13

faculty for arithmetic that you're using

10:15

but you might have learned oh okay right

10:16

now we're doing arithmetic problems

10:19

um

10:20

got it in the case this is i agree that

10:21

there's like an element of semantics

10:22

here um yeah you know this is helpful

10:25

though just to clarify exactly sort of

10:26

what the

10:27

yeah what you mean there

10:29

thank you of course

10:33

so something that's i think really

10:34

striking about all of this

10:36

um

10:37

is well okay so we we've talked about

10:38

how we can we can sort of look at the

10:40

learning curve and we can also look at

10:41

this in context learning curve but

10:43

really those are just two slices of a

10:45

two-dimensional uh space so like the the

10:49

in some sense the more fundamental thing

10:50

is how good are we at producing the nth

10:52

token at a given point in

10:53

training

10:54

and something that you'll notice if you

10:56

if you look at this and so when we when

10:58

we talk about the loss curve we're just

11:00

talking about if you average over this

11:01

dimension

11:02

if you if you like average like this and

11:04

and project on to the the training step

11:07

that's that's your loss curve um and um

11:10

if you the thing that we are calling the

11:11

incontext learning curve is just this

11:12

line um

11:15

uh yeah this this line uh down the end

11:18

here

11:20

um

11:21

and something that's that's kind of

11:22

striking is there's there's this

11:24

discontinuity in it um like there's this

11:26

point where where you know the model

11:29

seems to get radically better in a very

11:31

very short time step span at going and

11:33

predicting late tokens

11:35

so it's not that different in early time

11:37

steps but in late time steps suddenly

11:39

you get better

11:43

and a way that you can make this more

11:45

striking is you can you can take the

11:47

difference in in your ability to predict the

11:48

50th token and your ability to predict

11:51

the 500th token you can subtract from the

11:53

the the 500th token the 50th token loss

11:56

and what you see

11:58

um is that over the course of training

12:01

you know you're you're you're not very

12:02

good at this and you get a little bit

12:03

better and then suddenly

12:05

you have this cliff and then you never

12:07

get better than the difference between

12:09

these at least never gets better so the

12:10

model gets better at predicting things

12:12

but its ability to go and predict late

12:14

tokens over early tokens never gets

12:16

better

12:17

and so there's in the span of just a few

12:19

hundred steps in training the model has

12:21

gotten radically better at its ability

12:24

to go and and do this kind of in context

12:26

learning
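
To make the quantity concrete, here is a minimal sketch (not the speaker's code; the array shape and the toy numbers are assumptions for illustration) of the "in-context learning score" described above: the loss on the 500th token minus the loss on the 50th token, computed from per-position losses at one training checkpoint.

```python
import numpy as np

def in_context_learning_score(per_position_loss):
    """per_position_loss: array of shape (n_ctx,), the average loss (in nats)
    at each token position for one checkpoint. The score is the loss at the
    500th token minus the loss at the 50th token; more negative means the
    model is making better use of its context."""
    return float(per_position_loss[499] - per_position_loss[49])

# toy usage with made-up numbers for a 512-token context
losses = np.linspace(4.0, 3.0, 512)       # pretend later tokens are easier
print(in_context_learning_score(losses))  # roughly -0.88 for this toy input
```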

12:28

and so you might ask you know what's

12:30

going on at that point

12:31

um and this is just one model but um

12:33

well so first of all it's worth noting

12:35

this isn't a small a small change and so

12:38

um

12:39

that you can we don't think about this

12:41

very often but you know often we just

12:42

look at loss goes more like did the

12:43

model do better than another model or

12:44

worse than another model but um you can

12:46

you can think about this as in terms of

12:48

nats and that are are you know it's

12:50

just the information theoretic quantity

12:51

in that um and you can convert that into

12:54

bits and so like one one way you can

12:56

interpret this is it's it's something

12:58

roughly like you know the model 0.4 nats

13:01

is about 0.5 bits is about uh like every

13:03

other token the model gets to go and

13:04

sample twice um and pick the better one

13:07

it's actually it's even stronger than

13:08

that that's sort of an underestimate of

13:10

how big a deal going and getting better

13:12

by 0.4 nats so this is like a real

13:15

big difference in the model's ability to

13:16

go and predict late tokens
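
As a quick back-of-the-envelope check of that comparison (my arithmetic, in the units the speaker is using):

$$ 0.4\ \text{nats} \;=\; \frac{0.4}{\ln 2}\ \text{bits} \;\approx\; 0.58\ \text{bits}, $$

while getting to sample every other token twice and keep the better continuation is worth at most about one bit on half the tokens, i.e. roughly 0.5 bits per token on average, which is why the speaker calls that picture an underestimate of the improvement.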

13:20

um and we can visualize this in

13:21

different ways we can we can also go and

13:23

ask you know how much better are we

13:24

getting at going and predicting later

13:26

tokens and look at the derivative and

13:28

then we we can see very clearly that

13:29

there's there's some kind of

13:30

discontinuity in that derivative at this

13:32

point and we can take the second

13:34

derivative then and we can um with well

13:37

derivative with respect to training and

13:39

now we see that there's like there's

13:40

very very clearly this this line here so

13:43

something in just the span of a few

13:45

steps a few hundred steps is is causing

13:47

some big change and we have some kind of

13:50

phase change going on

13:51

um and this is true across model sizes

13:54

um

13:55

uh you can you can actually see it a

13:56

little bit in the loss curve and there's

13:58

this little bump here and that

13:59

corresponds to the point where you have

14:01

this you have this change we we actually

14:03

could have seen in the loss curve

14:04

earlier too it's it's this bump here

14:08

excuse me so so we have this phase

14:10

change going on and there's a i think a

14:12

really tempting theory to have which is

14:14

that somehow whatever you know there's

14:16

some this this change in the model's

14:18

output and its behaviors and it's in a

14:20

in a in in these sort of outward facing

14:22

properties corresponds presumably to

14:24

some kind of change in the algorithms

14:26

that are running inside the model so if

14:28

we observe this big phase change

14:29

especially in a very small window um in

14:32

in the model's behavior presumably

14:34

there's some change in the circuits

14:35

inside the model that is driving that

14:38

at least that's a you know a natural

14:39

hypothesis so

14:41

um if we want to ask that though we need

14:42

to go and be able to understand you know

14:44

what are the algorithms that's running

14:45

inside the model how can we turn the

14:47

parameters in the model back into those

14:48

algorithms so that's going to be our

14:49

goal

14:51

um now it's going to require

14:53

us to cover a lot of ground um in a

14:55

relatively short amount of time so i'm

14:56

going gonna go a little bit quickly

14:58

through the next section and i will

14:59

highlight sort of the the key takeaways

15:02

and then i will be very happy um to go

15:04

and uh you know explore any of this in

15:07

as much depth i'm free for another hour

15:09

after this call um and just happy to

15:11

talk in as much depth as people want

15:12

about the details of this

15:14

so

15:15

um it turns out this phase change

15:17

doesn't happen in a one layer attention

15:19

only transformer and it does happen in a

15:21

two-layer attention-only transformer so

15:23

if we could understand a one-layer

15:24

attention-only transformer and a two layer

15:26

attention-only transformer

15:28

that might give us a pretty big clue as

15:30

to what's going on

15:32

um

15:34

so we're attention only we're also going

15:35

to leave out layer norm and biases to

15:37

simplify things so you know you one way

15:39

you could describe a attention only

15:42

transformer

15:43

is we're going to embed our tokens

15:45

and then we're going to apply a bunch of

15:46

attention heads and add them into the

15:48

residual stream and then apply our

15:50

unembedding and that'll give us our

15:51

logits

15:53

and we can go and write that out as

15:54

equations if we want multiplied by an

15:55

embedding matrix

15:57

apply attention heads

15:59

and then compute the logits from the

16:00

unembedding
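
Here is a minimal numpy sketch of the model class just described (the sizes and random weights are my own toy choices, and normalization and biases are left out to match the simplified setup): embed the tokens, add every head's output into the residual stream, and unembed to get the logits.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

n_vocab, d_model, d_head, n_heads, n_ctx = 1000, 64, 16, 4, 32
rng = np.random.default_rng(0)
W_E = rng.normal(size=(n_vocab, d_model)) * 0.02      # token embedding
W_U = rng.normal(size=(d_model, n_vocab)) * 0.02      # unembedding
heads = [dict(W_Q=rng.normal(size=(d_model, d_head)),
              W_K=rng.normal(size=(d_model, d_head)),
              W_V=rng.normal(size=(d_model, d_head)),
              W_O=rng.normal(size=(d_head, d_model)))
         for _ in range(n_heads)]

tokens = rng.integers(0, n_vocab, size=n_ctx)
x0 = W_E[tokens]                                      # residual stream after embedding
mask = np.tril(np.ones((n_ctx, n_ctx), dtype=bool))   # causal mask
head_sum = np.zeros_like(x0)
for h in heads:
    scores = (x0 @ h["W_Q"]) @ (x0 @ h["W_K"]).T / np.sqrt(d_head)
    A = softmax(np.where(mask, scores, -1e9))         # attention pattern, rows = destinations
    head_sum += A @ (x0 @ h["W_V"]) @ h["W_O"]        # what each head writes back
logits = (x0 + head_sum) @ W_U                        # unembed the final residual stream
```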

16:03

um

16:04

and the part here that's a little tricky

16:06

is understanding the attention heads and

16:08

this might be a somewhat conventional

16:10

way of describing attention and it

16:12

actually kind of obscures a lot of the

16:14

structure of attention heads and i think

16:15

that oftentimes it's we we make

16:18

attention heads more more complex than

16:19

they are we sort of hide the interesting

16:21

structure

16:22

so what is this saying let's say you

16:23

know for every token compute a value

16:25

back a value vector and then go and mix

16:28

the value vectors according to the

16:29

attention matrix and then project them

16:31

with the output matrix back into the

16:32

residual stream

16:34

um so there's there's another notation

16:37

which you could think of this as a as

16:39

using tensor products or using um using

16:42

uh

16:43

well i guess there's a few left and

16:45

right multiplying there's a few ways you

16:46

can interpret this but um

16:48

i'll just sort of try to explain what

16:50

this notation means um

16:53

so this means

16:54

for every you know x or our residual

16:56

stream we have a vector for every single

16:59

token

17:00

and this means go and multiply

17:02

independently the vector for each token

17:04

by wv so compute the value vector for

17:07

every token

17:09

this one on the other hand means notice

17:10

that it's now on the a is on the left

17:12

hand side it means go and go and

17:14

multiply

17:15

the

17:17

attention matrix or go and go into

17:19

linear combinations of the values value

17:20

vectors so don't don't change the value

17:22

vectors you know point wise but go and

17:24

mix them together according to the

17:25

attention pattern create a weighted sum

17:28

and then again independently for every

17:30

position go and apply the output matrix

17:33

and you can apply the distributive

17:34

property to this and it just reveals

17:36

that actually it didn't matter that you

17:37

did the attention sort of in the middle

17:39

you could have done the attention at the

17:40

beginning you could have done it at the

17:41

end um that's that's independent um and

17:44

the thing that actually matters is

17:45

there's this wvwo matrix that describes

17:48

what it's really saying is you know

17:50

wvw describes what information the

17:52

attention head reads from each position

17:54

and how it writes it to its destination

17:56

whereas a describes which tokens we read

17:59

from and write to

18:00

um and that's that's kind of getting

18:02

more the fundamental structure and

18:03

attention an attention head goes and

18:05

moves information from one position to

18:07

another and the process of of which

18:10

position gets moved from and to is

18:11

independent from what information gets

18:13

moved
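
In symbols, this is the rewrite the speaker is describing: with $A$ the attention pattern and $x$ holding one residual-stream vector per token,

$$ h(x) \;=\; \big(A \otimes W_O W_V\big)\cdot x , $$

where $A$ controls which positions information moves between, and $W_O W_V$ controls what is read from the source position and how it is written at the destination, so the two factors can be studied separately.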

18:17

and if you rewrite your transformer that

18:19

way

18:20

well first we can go and write uh the

18:22

sum of attention heads just as as in

18:25

this form

18:26

um and then we can uh go and write that

18:29

as the the entire layer by going and

18:31

adding an identity

18:34

and if we go and plug that all in to our

18:36

transformer and go and expand

18:40

um we we have to go and multiply

18:43

everything through we get this

18:45

interesting equation and so we get this

18:47

one term this corresponds to just the

18:49

path directly through the residual

18:50

stream

18:51

and it's going to want to store uh

18:53

bigram statistics it's just you know all

18:55

it gets is the previous token and tries

18:56

to predict the next token

18:58

and so it gets to try and predict uh try

19:00

to store bi-gram statistics and then for

19:02

every attention head we get this matrix

19:04

that says okay well for we have the

19:05

attention pattern so it looks that

19:07

describes which token looks at which

19:08

token and we have this matrix here which

19:10

describes how for every possible token

19:12

you could attend to

19:13

how it affects the logits and that's

19:16

just a table that you can look at it

19:17

just says you know for for this

19:18

attention head if it looks at this token

19:20

it's going to increase the probability

19:21

of these tokens in a one layer attention

19:23

only transformer that's all there is
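
Written out, the expansion being described for a one-layer attention-only transformer (in the same notation as above, ignoring positional terms) is

$$ T \;=\; \mathrm{Id}\otimes W_U W_E \;+\; \sum_{h} A^{h}\otimes \big(W_U\, W_O^{h} W_V^{h}\, W_E\big) , $$

where the first term is the direct path through the residual stream (bigram statistics), and each attention head contributes one term: its attention pattern $A^h$ (which token looks at which token) times a vocab-by-vocab table $W_U W_O^h W_V^h W_E$ (how attending to a token changes the logits).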

19:28

um yeah so this is just just the

19:30

interpretation i was describing

19:32

um

19:34

and another thing that's worth noting is

19:36

this

19:36

um according to this the attention only

19:38

transformer is linear if you fix the

19:41

attention pattern now of course it's the

19:43

attention pattern isn't fixed but

19:44

whenever you even have the opportunity

19:46

to go and make something linear linear

19:47

functions are really easy to understand

19:49

and so if you can fix a small number of

19:50

things and make something linear that's

19:51

actually

19:52

a lot of leverage

19:54

okay

19:56

um

19:57

and yeah we can talk about how the

19:59

attention pattern is computed as well

20:01

um you if you expand it out you'll get

20:03

an equation like this

20:05

and uh notice well i think i think it'll

20:08

be easier

20:09

okay

20:12

the i think the core story though to

20:14

take away from all of these is we have

20:16

these two matrices that actually look

20:17

kind of similar so

20:19

this one here

20:20

tells you if you attend to a token

20:22

how are the logits affected

20:24

and it's you can just think of it as a

20:25

giant matrix of for every possible token

20:28

input token how how are

20:30

the logits affected

20:31

by that token are they made more likely

20:33

or less likely

20:34

and we have this one which sort of says

20:36

how much does every token want to attend

20:38

to every other token

20:42

um one way that you can you can picture

20:44

this is

20:46

uh okay that's really there's really

20:48

three tokens involved when we're

20:49

thinking about an attention head we have

20:51

the

20:52

token that

20:54

we're going to move information to and

20:56

that's attending backwards

20:58

we have the source token that's going to

21:00

get attended to and we have the output

21:02

token whose logits are going to be

21:03

affected

21:04

and you can just trace through this so

21:06

you can ask what happens um how does the

21:09

the attending to this token affect the

21:11

output well first we embed the token

21:14

then we multiply by wv to get the value

21:17

vector the information gets moved by the

21:19

attention pattern

21:20

we multiply by wo to add it back into

21:22

the residual stream we get hit by the

21:24

unembedding and we affect the logits and

21:26

that's where that one matrix comes from

21:28

and we can also ask you know what

21:29

decides you know whether a token gets a

21:31

high score when we're when we're

21:33

computing the attention pattern and it

21:34

just says

21:35

you know embed embed the token

21:39

turn it into a query embed the other

21:40

token turn it into a key

21:43

and dot product to them and see you

21:44

that's where those those two matrices

21:46

come from
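
As a formula (the exact transposes depend on whether tokens are treated as rows or columns, so take this as a sketch of the structure rather than a precise convention), the attention score between a destination token and a source token is

$$ \text{score}(t_{\text{dst}}, t_{\text{src}}) \;=\; \big(W_E\, t_{\text{dst}}\big)^{\!\top} W_Q^{\top} W_K \,\big(W_E\, t_{\text{src}}\big) , $$

so the matrix governing attention is the vocab-by-vocab "QK circuit" $W_E^{\top} W_Q^{\top} W_K W_E$, while the matrix governing the effect on the logits is the "OV circuit" $W_U W_O W_V W_E$ traced through in the previous step.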

21:47

so i know that i'm going quite quickly

21:50

um

21:51

maybe i'll just briefly pause here and

21:54

if anyone wants to ask for

21:55

clarifications uh this would be a good

21:57

time and then we'll actually go and

21:59

reverse engineer and and say you know

22:01

everything that's going on in a

22:02

one-layer attention-only transformer is now in

22:03

the palm of our hands

22:05

it's a very toy model

22:07

you know no one actually uses one layer

22:09

attention only transformers but we'll

22:10

be able to understand the one layer

22:12

attention only transformer

22:16

so just to be clear so you're saying

22:17

that yes the the

22:19

query key circuit is learning the

22:21

attention weights

22:23

and like essentially it's responsive

22:24

running the sort of attention between

22:26

different uh tokens

22:27

yeah yeah so

22:29

so this this matrix when it yeah you

22:31

know all three of those parts are

22:33

learned but that's that's what expresses

22:35

whether

22:36

a attention pattern is yeah that's what

22:38

generates the attention patterns gets

22:40

run for every pair of tokens and you can

22:42

you can you can think of values in that

22:43

matrix as just being how much every

22:45

token wants to attend to every other

22:46

token if it was in the context and we're

22:49

we're ignoring positional embeddings here

22:50

so there's a little bit that we're sort

22:51

of aligning over there as well but sort

22:53

of in in a global sense how much does

22:55

every token want to attend every other

22:56

token right

22:57

and the other circuit like the output

22:59

value circuit is

23:00

using the attention that's calculated to

23:03

guess

23:05

like affect the final outputs it's sort

23:08

of saying if if the attention head

23:09

assume that the attention head attends

23:11

to some token so let's set aside the

23:12

question of how that gets computed just

23:14

assume that it hence to some token how

23:16

would it affect the outputs if it

23:17

attended to that token

23:19

and you can just you can just calculate

23:20

that um it's just a big table of values

23:23

that says you know for this token

23:24

it's going to make this token more

23:25

likely this token will make this token

23:27

less likely

23:29

right okay

23:31

and it's completely independent like

23:32

it's just two separate matrices they're

23:34

they're not you know the the formulas

23:37

that might make them seem entangled but

23:38

they're actually separate

23:40

all right so to me it seems like the

23:43

direct supervision is coming from the

23:44

output value circuit and the query key

23:46

circuit seems to be more like

23:47

unsupervised kind of thing because

23:48

there's no

23:50

i mean they're just i think in the sense

23:53

that every in in an yeah in a model like

23:56

every every neuron is in some sense you

23:57

know like

23:59

signals is is somehow downstream from

24:02

the ultimate the ultimate signal and so

24:04

you know the output value signal the

24:06

output value circuit is getting more

24:08

more direct is perhaps getting more

24:09

direct signal correct um but yeah

24:12

yes

24:16

we will be able to dig into this in lots

24:18

of detail in as much detail as you want

24:20

uh in a little bit so we can um maybe

24:23

i'll push forward and i think also

24:24

actually an example of how to use this

24:26

reverse engineer one layer model will

24:28

maybe make it a little bit more more

24:29

motivated

24:31

okay so

24:32

um just just to emphasize this there's

24:35

three different tokens that we can talk

24:37

about there's a token that gets attended

24:38

to

24:39

there's the token that does the

24:40

attention which are called the

24:42

destination and then there's the token

24:43

that gets affected that is the next

24:45

token whose probabilities are

24:46

affected

24:48

um

24:49

and so something we can do is notice

24:51

that the the only token that connects to

24:53

both of these is the token that gets

24:55

attended to

24:56

so these two are sort of they're they're

24:58

bridged

24:59

by their their interaction with the

25:01

source token so something that's kind of

25:02

natural is to

25:04

ask for a given source token you know

25:06

how does it interact with both of these

25:09

so let's let's take for instance the

25:11

token perfect

25:13

which tokens for one thing we can ask is

25:15

which tokens want to attend to perfect

25:19

well apparently the tokens that most

25:20

want to attend to perfect are are and

25:23

looks and is and provides

25:26

um so r is the most looks is the next

25:27

most and so on

25:29

and then when we attempt to perfect and

25:30

this is with one one single attention

25:32

head so you know it'd be different if we

25:33

did a different attention head

25:35

it wants to really increase the

25:36

probability of perfect and then to a

25:38

lesser extent super and absolute and

25:40

pure and we can ask you know what what

25:44

sequences of tokens are made more likely

25:47

by this

25:48

this particular

25:49

um set of you know this particular set

25:51

of things wanting to attend to each

25:53

other and becoming more likely well

25:55

things are the form

25:57

we have our tokens we attended back to

25:59

and we have some

26:00

some skip of some number of tokens they

26:02

don't have to be adjacent but then later

26:03

on we see the token r and it attends

26:05

back to perfect and increases the

26:07

probability of perfect

26:09

so you can you can think of these as

26:10

being like we're sort of creating

26:12

changing the probability of what we

26:13

might call might call skip trigrams

26:15

where we have you know we skip over a

26:17

bunch of tokens in the middle but we're

26:18

affecting the probability really of of

26:20

trigrams

26:22

so perfect ... are -> perfect and perfect ... looks ->

26:23

super

26:25

um we can look at another one so we have

26:26

the token large

26:28

um these tokens contains using specify

26:30

want to go and look back to it and it

26:32

increases probability of large and small

26:34

and the skip trigrams that are affected

26:36

are things like large ... using -> large and

26:39

large ... contains -> small

26:42

and things like this

26:44

um if we see the number two and we

26:46

increase the probability of other

26:47

numbers and we affect probable tokens or

26:51

skip trigrams like two one two

26:53

two

26:54

has three

26:56

um

26:57

now you're you're all in uh in a

27:00

technical field so you'll probably

27:01

recognize this one we have uh have

27:02

lambda and then we see backslash and

27:06

then we want to increase the probability

27:07

of lambda and sorted and lambda and

27:09

operator so it's all LaTeX

27:12

and it wants to and it's if it sees

27:14

lambda it thinks that you know maybe

27:15

next time i use a backslash i should go

27:17

and put in some latex

27:19

math symbol

27:21

um

27:22

also

27:22

same thing for html we see nbsp for

27:25

non-breaking space and then we see an

27:27

ampersand we want to go and make that

27:28

more likely

27:29

the takeaway from all this is that a one

27:31

layer attention-only transformer is totally

27:33

acting on these skip trigrams

27:36

um

27:37

every everything that it does i mean i

27:38

guess it also has this pathway by which

27:39

it affects bi-grams but mostly it's just

27:41

affecting these skip trigrams

27:43

um and there's lots of them it's just

27:44

like these giant tables of skip trigrams

27:46

that are made more or less likely
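
A small sketch of how one could read those tables off a single head (the toy vocabulary size, random weights and index choices are all my own assumptions; a real analysis would use the trained weights and the tokenizer): the QK matrix says which destination tokens want to attend to a given source token, the OV matrix says which output tokens get boosted when it is attended to, and together they give the skip trigrams "source ... destination -> output".

```python
import numpy as np

rng = np.random.default_rng(0)
n_vocab, d_model, d_head = 50, 16, 4
W_E = rng.normal(size=(n_vocab, d_model))
W_U = rng.normal(size=(d_model, n_vocab))
W_Q = rng.normal(size=(d_model, d_head))
W_K = rng.normal(size=(d_model, d_head))
W_V = rng.normal(size=(d_model, d_head))
W_O = rng.normal(size=(d_head, d_model))

# QK circuit: (destination token, source token) -> attention score
QK = (W_E @ W_Q) @ (W_E @ W_K).T          # (n_vocab, n_vocab)
# OV circuit: (source token, output token) -> change in logits when attended to
OV = (W_E @ W_V @ W_O) @ W_U              # (n_vocab, n_vocab)

src = 7                                            # pick some source token id
dst_that_attend = np.argsort(QK[:, src])[::-1][:5] # tokens that most want to look at it
out_promoted    = np.argsort(OV[src])[::-1][:5]    # tokens whose logits it most increases
print(dst_that_attend, out_promoted)               # skip trigrams: src ... dst -> out
```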

27:50

um

27:51

there's lots of other fun things that

27:52

does sometimes the tokenization will

27:53

split up a word in multiple ways so um

27:56

like we have indie

27:57

well that's that's not a good example we

27:59

have like the word pike and then we

28:01

we see the the token p and then we

28:03

predict ike

28:05

um when we predict spikes and stuff like

28:07

that um

28:08

or

28:09

these these ones are kind of fun maybe

28:10

they're actually worth talking about for

28:11

a second so we see

28:13

the token void

28:15

and then we see an l and maybe we

28:17

predict lloyd

28:18

um or r and we predict ralph

28:21

and c catherine

28:23

and but we'll see in a second then well

28:25

yeah we'll come back to that in a sec so

28:27

we increase the probability of things

28:28

like lloyd's lloyd and lloyd catherine

28:30

or pixmap

28:32

if anyone's worked with qt

28:34

um it's we see pics map and we increase

28:36

the probability of um p

28:39

xmap again but also

28:41

q

28:42

canvas um

28:45

yeah

28:47

but of course there's a problem with

28:48

this which is um

28:50

it doesn't get to pick which one of

28:51

these goes with which one

28:52

so if you want to go and make pixmap

28:55

pixmap

28:56

and pixmap q canvas more probable you

28:59

also have to go and make

29:01

pixmap p canvas

29:03

more probable

29:04

and if you want to make lloyd lloyd and

29:06

lloyd catherine

29:08

more probable you also have to make

29:09

lloyd cloyd and lloyd lathron

29:13

more probable

29:14

and so there's actually like bugs that

29:16

transformers have like weird at least

29:17

and you know and these these really tiny

29:19

one-layer attention only transformers

29:21

there there's these bugs that you know

29:22

they seem weird until you realize that

29:24

it's this giant table of skip trigrams

29:26

that's that's operating

29:28

um

29:28

and the the nature of that is that

29:30

you're going to be

29:31

um

29:33

uh yeah you it sort of forces you if you

29:35

want to go and do this to go in and also

29:37

make some weird predictions

29:40

chris

29:42

is there a reason why the source tokens

29:44

here have a space before the first

29:45

character

29:46

yes um that's just the i was giving

29:49

examples where the tokenization breaks

29:51

in a particular way and okay um because

29:54

spaces get included in the tokenization

29:57

um

29:58

when there's a space in front of

29:59

something and then there's an example

30:01

where the space isn't in front of it

30:02

they can get tokenized in different ways

30:04

got it cool thanks

30:06

great question

30:10

um

30:11

okay so some just to abstract away some

30:13

common patterns that we're seeing i

30:14

think um

30:16

one pretty common thing is what you

30:17

might describe as like d

30:19

a b so you're you go and you you see

30:22

some token and then you see another

30:23

token that might precede that token then

30:25

you're like ah probably the token that i

30:26

saw earlier is going to occur again

30:28

um or sometimes you you predict a

30:31

slightly different token so like maybe

30:33

maybe an example the first one is two

30:35

one two

30:36

but you could also do two

30:38

has three

30:39

and so three isn't the same as two but

30:41

it's kind of similar so that's that's

30:42

one thing another one is this this

30:44

example where you've a token that

30:45

something it's tokenized together one

30:46

time and that's split apart so you see

30:49

the token and then you see something

30:50

that might be the first part of the

30:51

token and then you predict the second

30:53

part

30:55

and

30:56

i think the thing that's really striking

30:58

about this is these are all in some ways

31:01

a really crude kind of in context

31:03

learning

31:04

and and

31:06

in particular these models get about 0.1

31:08

nats rather than about 0.4 nats

31:10

of in-context learning and they never go

31:12

through the phase change so they're

31:13

doing some kind of really crude in-

31:15

context learning and also they're

31:16

dedicating almost all their attention

31:18

heads to this kind of crude in-context

31:20

learning so they're not very good at it

31:21

but they're they're they're dedicating

31:23

their um their capacity to it

31:26

uh i'm noticing that it's 10:37 um

31:29

i i want to just check how long i can go

31:31

because i maybe i should like super

31:32

accelerate because this is

31:34

chris uh i think it's fine because like

31:36

students are also asking questions in

31:38

between such uh you should be good

31:40

okay so maybe my plan will be that i'll

31:42

talk until like 10:55 or 11 and then if

31:45

you want i can go and answer questions

31:48

for a while after after that

31:50

yeah it works

31:51

fantastic

31:53

so you can see this as a very crude kind

31:55

of in context learning like basically

31:56

what we're saying is it's sort of all

31:58

this flavor of okay well i saw this

31:59

token probably these other tokens the

32:01

same token or similar tokens are more

32:03

likely to go and occur later and look

32:05

this is an opportunity that sort of

32:06

looks like i can inject the token that i

32:08

saw earlier i'm going to inject it here

32:09

and say that it's more likely that's

32:10

like that's basically what it's doing

32:13

and it's dedicating almost all of its

32:14

capacity to that so you know these it's

32:16

sort of the opposite of what we thought

32:17

with rnn's in the past like used to be

32:19

that everyone was like oh you know rnn's

32:20

it's so hard to care about long distance

32:23

context you know maybe we need to go

32:25

and like use dams or something no if you

32:27

if you train a transformer it dedicates

32:29

and you give it a long a long enough

32:30

context it's dedicating almost all of

32:32

its capacity um to this type of stuff um

32:35

just kind of interesting

32:38

um there are some attention heads which are

32:40

more primarily positional um usually we

32:43

in the model that i've been training

32:44

that has two layer or it's only a one

32:46

layer model has twelve attention heads

32:48

and usually around two or three of those

32:49

will become these more positional sort

32:51

of shorter term things that do something

32:52

more like like local trigram statistics

32:55

and then everything else becomes these

32:56

skip trigrams

33:00

um yeah so uh some takeaways from this

33:03

uh yeah you can you can understand one

33:06

layer attention-only transformers in terms

33:07

of these ov and qk circuits um

33:10

transformers desperately want to do

33:12

in-context learning they desperately

33:14

desperately desperately want to go and

33:16

and look at these long distance contexts

33:18

and go and predict things there's just

33:19

so much so much entropy that they can go

33:21

and reuse out of that

33:23

the constraints of a one-layer

33:25

attention-only transformer force it to

33:26

make certain bugs but it wants to do the

33:27

right thing

33:28

um

33:29

and if you freeze the attention patterns

33:31

these models are linear

33:33

okay

33:34

um

33:35

a quick aside because so far this type

33:38

of work has required us to do a lot of

33:40

very manual inspection like we're

33:41

walking through these giant matrices but

33:43

there's a way that we can escape that we

33:44

don't have to use look at these giant

33:46

matrices if we don't want to

33:47

um we can use eigenvalues and

33:49

eigenvectors so recall that an

33:50

eigenvalue

33:52

and an eigenvector just means that if

33:54

you if you multiply that vector by the

33:56

matrix um it's equivalent to just

33:58

scaling

34:00

and

34:01

uh often

34:02

in my experience those haven't been very

34:03

useful for interpretability because

34:04

we're usually mapping between different

34:06

spaces but if you're mapping onto the

34:07

same space either values either vectors

34:09

are a beautiful way to think about this

34:11

um so we're going to draw them um

34:14

on a

34:15

a radial plot

34:17

um and we're going to have a log

34:19

uh radial scale because they're gonna

34:20

vary their magnitude's gonna vary in by

34:22

many orders of magnitude

34:25

um okay so we can just go and you know

34:27

our ov circuit maps from tokens to

34:29

tokens that's the same vector space and

34:30

the input and the output and we can ask

34:32

you know what does it mean if we see

34:34

eigenvalues of a particular kind well

34:36

positive eigenvalues and this is really

34:38

the most important part mean copying so

34:40

if you have a positive eigenvalue it

34:41

means that there's some set of of tokens

34:44

where if you if you see them you

34:46

increase their probability

34:47

and if you have a lot of positive

34:48

eigenvalues um you're doing a lot of

34:50

copying if you only have positive

34:52

eigenvalues everything you do is copying

34:54

um now imaginary eigenvalues mean that

34:56

you see a token and then you want to go

34:58

and increase the probability of

34:59

unrelated tokens and finally negative

35:01

eigenvalues are anti-copying they're

35:02

like if you see this token you make it

35:04

less probable in the future
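
Putting the same statement in symbols: the OV circuit $M = W_U W_O W_V W_E$ maps the token space to itself, so for an eigenpair

$$ M v \;=\; \lambda v , $$

$\lambda > 0$ means attending to tokens along $v$ pushes the logits of those same tokens up (copying), a large imaginary part of $\lambda$ means the probability mass is pushed toward unrelated tokens, and $\lambda < 0$ means the same tokens are pushed down (anti-copying).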

35:07

well that's really nice because now we

35:08

don't have to go and dig through these

35:09

giant matrices that are vocab size by

35:11

vocab size we can just look at the

35:13

eigenvalues

35:14

um and so these are the eigenvalues for

35:16

our one layer attention only transformer

35:18

and we can see that you know

35:21

for

35:22

many of these they're almost entirely

35:24

positive these events are are sort of

35:26

entirely positive these ones are almost

35:28

entirely positive and really these ones

35:30

are even almost entirely positive and

35:31

there's only two

35:33

that have a significant number of

35:34

imaginary and negative eigenvalues

35:37

um and so what this is telling us is

35:39

it's just in one picture we can see you

35:40

know okay they're really you know

35:44

10 out of 12 of these of these attention

35:46

heads are just doing copying they just

35:47

they just are doing this long distance

35:48

you know well i saw a token probably

35:50

it's going to occur again type stuff um

35:52

that's kind of cool we can we can

35:53

summarize it really quickly

35:56

okay

35:57

um

35:58

now the other thing that you can yeah so

36:00

this is this is for a second we're gonna

36:01

look at a two-layer model in a second

36:03

and we'll we'll see that also a lot of

36:04

its heads are doing this kind of copying

36:06

or stuff they have large positive

36:07

eigenvalues

36:10

um you can do a histogram like you know

36:12

one one thing that's cool is you can

36:13

just add up the uh the eigenvalues and

36:15

divide them by their absolute values and

36:16

you get a number between zero and one

36:18

which is like how copying-ish

36:19

the head is or between negative one

36:21

and one how copying-ish the head is you

36:22

can just do a histogram you can see oh

36:24

yeah almost all the heads are doing

36:26

doing lots of copying
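
A sketch of the summary statistic just described (a toy random matrix stands in for a real head's OV circuit; for a full-size model one would exploit the fact that the OV circuit is low rank rather than forming the whole vocab-by-vocab matrix): sum the eigenvalues and divide by the sum of their magnitudes to get a number between -1 and 1 measuring how much of what the head does is copying.

```python
import numpy as np

def copying_score(ov_matrix):
    """Sum of eigenvalues over sum of their magnitudes: close to +1 means the
    head is doing almost pure copying, close to -1 means anti-copying."""
    eig = np.linalg.eigvals(ov_matrix)
    return float(np.real(eig.sum()) / np.abs(eig).sum())

# toy stand-in for one head's OV circuit (W_U W_O W_V W_E in the real model)
rng = np.random.default_rng(0)
ov = rng.normal(size=(100, 100))
print(copying_score(ov))
```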

36:28

you know it's nice to be able to go and

36:29

summarize your model you know uh and i

36:31

think this is this is sort of like we've

36:32

gone for a very bottom-up way and we

36:35

didn't start with assumptions about what

36:36

model is doing we tried to understand

36:37

its structure and then we were able to

36:38

summarize it in useful ways and now

36:40

we're able to go and say something about

36:41

it

36:43

um now another thing you might ask is

36:45

what what do the the eigenvalues of the

36:46

qk circuit mean and in our example so

36:48

far they haven't been that they wouldn't

36:50

have been that interesting but in a

36:52

minute they will be and so i'll briefly

36:53

describe what they mean a positive

36:55

eigenvalue would mean you want to attend

36:56

to the same tokens

36:58

and an imaginary eigenvalue and this is

37:00

what you would mostly see in our models

37:01

we've seen so far means you want to go

37:03

in and attend to a unrelated or

37:05

different token

37:06

and a negative eigenvalue would mean you

37:08

want to avoid attending to the same

37:09

token

37:11

so that will be relevant in a second

37:14

um yes so those are going to mostly be

37:16

useful to think about in in multi-layer

37:18

attention-only transformers when we kind of

37:19

change the attention heads and so we can

37:21

ask you know well i'll get to that in a

37:23

second yeah so there's a table

37:24

summarizing that um unfortunately this

37:27

this approach completely breaks down

37:28

once you have mlp layers mlp layers you

37:30

know now you have have these

37:32

non-linearities and so you don't get

37:33

this property where your model is mostly

37:34

linear and you can you can just look at

37:36

a matrix but if you're working with only

37:37

attention only transformers this is a

37:38

very nice way to think about effects

37:40

okay so recall that one-layer

37:42

attention-only transformers don't undergo

37:44

this phase change that we talked about

37:45

in the beginning like right now we're on

37:46

a hunt we're trying to go and answer

37:48

this mystery of how what the hell is

37:50

going on in that phase change where

37:51

models suddenly get good at in context

37:52

learning um we want to answer that and

37:55

one layer attention only transformers

37:56

don't undergo that phase change but two

37:58

layer attention-only transformers do so

38:00

we'd like to know what's different about

38:01

two layer attention only transformers

38:06

um

38:07

okay well so in our in our previous when

38:09

we're dealing with one layer attention

38:11

transformers we're able to go and

38:12

rewrite them in this

38:13

this form and it gave us a lot of

38:15

ability to go and understand the model

38:17

because we could go and say well you

38:18

know this is bi-grams and then each one

38:20

of these is looking somewhere and we

38:22

have this matrix that describes how it

38:23

affects things and

38:26

and yeah so that gave us a lot of a lot

38:28

of ability to think about this thing

38:30

these things and we we can also just

38:32

write in this factored form where we

38:33

have the embedding and then we have the

38:34

attention heads and then we have the

38:35

unembedding

38:37

okay well

38:39

um

38:40

oh and for simplicity we often go and

38:42

write wov for wo wv because they always

38:45

come together it's always the case like

38:47

it's it's in some sense an illusion that

38:48

w o and wv are different matrices

38:50

they're just one low rank matrix they're

38:51

never they're they're always used

38:53

together and similarly w q and w k it's

38:55

sort of an illusion that they're they're

38:56

different matrices um they're they're

38:58

always just used together and and keys

39:00

and queries are just sort of they're

39:01

just an artifact of this of these low

39:03

rank matrices

39:05

so in any case it's useful to go and

39:06

write those together
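
In symbols, the combined matrices being introduced here are

$$ W_{OV}^{h} \;=\; W_O^{h} W_V^{h}, \qquad W_{QK}^{h} \;=\; \big(W_Q^{h}\big)^{\!\top} W_K^{h}, $$

both of rank at most the head dimension; keys, queries and values never appear on their own in the function the model computes, only through these two products.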

39:07

um okay great so um a two-layer

39:09

attention only transformer what we do is

39:11

we we go through the embedding matrix

39:14

then we go through the layer one

39:16

attention heads then we go through the

39:17

layer two attention heads

39:19

and then we go through the unembedding

39:20

and for the the attention is we always

39:22

have this identity as well which

39:24

corresponds just going down the residual

39:25

stream so we can

39:27

uh go down the residual stream or we can

39:28

go through an attention head

39:31

next step we can also go down the

39:32

residual stream or we can go through an

39:33

attention head

39:37

um and there's this useful identity uh

39:40

the mixed product identity that um any

39:42

tensor product or or other ways of

39:44

interpreting this um obey which is that

39:47

if you have an attention head

39:49

um and we have say you know we have the

39:50

weights and the attention pattern and

39:52

the wov matrix and the attention pattern

39:54

the attention patterns multiply together

39:56

and the ov circuits multiply together

39:58

um and they behave nicely okay great so
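
The identity being used, written out for two heads composed across layers (same notation as before):

$$ \big(A^{h_2}\otimes W_{OV}^{h_2}\big)\cdot\big(A^{h_1}\otimes W_{OV}^{h_1}\big) \;=\; \big(A^{h_2}A^{h_1}\big)\otimes\big(W_{OV}^{h_2}\, W_{OV}^{h_1}\big) , $$

so the attention patterns multiply together on one side and the OV matrices multiply together on the other.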

40:02

and we can just expand out that equation

40:04

we can just take that big product we had

40:05

at the beginning we just expanded out

40:07

and we get three different kinds of

40:08

terms so one thing we do is we get this

40:10

this path that just goes directly

40:12

through the residual stream where we

40:13

embed and unembed and that's going to

40:14

want to represent some bigram statistics

40:18

um then we get things that look like

40:21

the attention head terms that we had

40:23

previously

40:26

and finally

40:28

we get these terms that correspond to

40:31

going through two attention heads

40:36

and

40:37

now it's worth noting that these terms

40:39

are not actually the same as they're

40:41

because the attention head uh the

40:42

attention patterns in the next in the

40:43

second layer can be computed from the

40:45

outputs of the first layer there those

40:47

are also going to be more expressive but

40:48

at a high level you can think of there

40:50

as being these three different kinds of

40:51

terms and we sometimes call these terms

40:53

virtual attention heads because they they

40:55

don't exist in the sense that they aren't

40:56

sort of explicitly represented in the

40:57

model but um they in fact you know they

41:00

have an attention pattern they have an ov

41:02

circuit they're sort of in almost all

41:04

functional ways like a tiny little

41:05

attention head and there's exponentially

41:07

many of them
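
For reference, the expansion being described, with the caveat that the layer-two attention patterns are themselves computed from the layer-one outputs, so the terms are not independent:

$$ T \;=\; \mathrm{Id}\otimes W_U W_E \;+\; \sum_{h} A^{h}\otimes\big(W_U W_{OV}^{h} W_E\big) \;+\; \sum_{h_2\in L_2}\sum_{h_1\in L_1} \big(A^{h_2}A^{h_1}\big)\otimes\big(W_U W_{OV}^{h_2} W_{OV}^{h_1} W_E\big) , $$

the direct path (bigrams), the individual attention heads of both layers, and the "virtual attention heads" formed by composing a layer-one head with a layer-two head.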

41:08

um

41:09

it turns out they're not going to be

41:10

that important in this model but in

41:11

other models they can be important

41:14

um right so one one thing that i said it

41:16

allows us to think about attention in a

41:17

really principled way we don't have to

41:19

go and think about um

41:22

you know i think there's like people

41:23

people look at attention patterns all

41:25

the time and i think a concern you could

41:27

have is well you know

41:28

there's multiple attention patterns like

41:30

you know the information that's been

41:31

moved by one attention head it might have

41:32

been moved there by another attention

41:33

head and not originally there it might

41:35

still be moved somewhere else um but in

41:38

fact this gives us a way to avoid all

41:39

those concerns and just think about

41:40

things in a single principled way

41:43

um okay in any case um an important

41:45

question to ask is how important are

41:47

these different terms well we could

41:49

study all of them how important are they

41:51

um

41:52

and it turns out um you can just

41:54

there's an algorithm you can use where

41:55

you knock out attention

41:57

knock out these terms and you go and you

41:59

ask how important are they
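
The talk doesn't spell out the knock-out procedure, but the bookkeeping looks roughly like this sketch: treat the expansion as a sum of terms in logit space, zero out one term, and see how much the loss moves. Everything below is a random toy stand-in; a real analysis would compute each term from the model's weights and measure loss on real text:

```python
import numpy as np

rng = np.random.default_rng(0)
n_pos, d_vocab = 6, 10

# Random stand-ins for a few terms of the expansion, each mapping positions
# to logits; in a real analysis these would come from the model's weights.
terms = {
    "direct_path": rng.normal(scale=0.5, size=(n_pos, d_vocab)),
    "head_1_0": rng.normal(scale=0.5, size=(n_pos, d_vocab)),
    "head_2_3": rng.normal(scale=0.5, size=(n_pos, d_vocab)),
    "virtual_2_3_then_1_0": rng.normal(scale=0.1, size=(n_pos, d_vocab)),
}
targets = rng.integers(0, d_vocab, size=n_pos)

def loss(term_values):
    logits = sum(term_values)               # the terms add up in logit space
    logp = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    return -logp[np.arange(n_pos), targets].mean()

base = loss(terms.values())
for name in terms:
    ablated = loss(v for k, v in terms.items() if k != name)
    print(f"knock out {name:>22}: loss {base:.3f} -> {ablated:.3f}")
```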

42:00

um and it turns out that by far the

42:03

most important thing is these individual

42:05

attention head terms in this model far

42:07

by far the most important thing the

42:08

virtual attention heads basically

42:10

don't matter that much

42:12

they only have an effect of 0.3 nats

42:14

compared to the above ones and the

42:15

bigrams are still pretty useful so if we

42:17

want to try to understand this model we

42:19

should probably go and focus our

42:20

attention on you know the virtual

42:21

attention heads are not going to be the

42:23

best place to go and focus

42:26

our attention especially since there's

42:27

there's a lot of them there's 124 of

42:29

them for 0.3 nats so it's very little that

42:31

you would gain for studying each one

42:33

of those terms

42:34

so the thing that we probably want to do

42:36

we know that these are bigram statistics

42:38

so what we really want to do is we want

42:39

to understand

42:40

the

42:41

the individual attention head terms

42:46

um this is the algorithm i'm going to

42:48

skip over it for time

42:49

we can ignore that term because it's

42:50

small

42:51

um

42:52

and it turns out also that the layer two

42:54

attention heads are doing way more than

42:56

layer one attention heads so that's not

42:58

that surprising like the layer two

43:00

attention heads are more expressive because they

43:01

can use the layer one attention heads to

43:02

construct their attention patterns

43:05

okay so uh if we could just go and

43:07

understand the layer two attention heads

43:08

we probably understand a lot of what's

43:10

going on in this model

43:12

um

43:14

and the trick is that the attention

43:15

heads are now constructed from the

43:17

previous layer rather than just from the

43:18

tokens so this is still the same but the

43:20

attention head the attention pattern is

43:22

more more complex and if you write it

43:24

out you get this complex equation that

43:26

says you know you embed the tokens then

43:28

you're going to shuffle things around

43:30

using the attention heads for the keys

43:31

then you multiply by W_QK then you

43:33

shuffle things around again for

43:34

the queries and then you go and multiply

43:36

by the embedding again because they were

43:38

embedded and then you get back to the

43:39

tokens um
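
Glossing over positional terms, scaling, and the softmax, that equation can be sketched in the earlier notation: the keys and queries are formed from the residual stream after layer one, which is itself a function of the tokens. This is a reconstruction of the idea, not the exact slide:

```latex
% Residual stream after layer 1, starting from one-hot tokens t:
x \;=\; \Big(\mathrm{Id} + \sum_{h \,\in\, \text{layer 1}} A^{h} \otimes W_{OV}^{h}\Big)\, W_E\, t
% Pre-softmax score of a layer-2 head between query position i and key position j:
s_{ij} \;\approx\; x_i^{\top}\, W_{QK}\, x_j ,
\qquad W_{QK} \;=\; W_Q^{\top} W_K
```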

43:42

uh

43:43

but let's actually look at them so

43:46

uh one thing to remember is that when

43:47

we see positive eigenvalues in the ov

43:49

circuit we're doing copying so one thing

43:51

we can say is well 7 out of 12 and in

43:53

fact the ones with the largest

43:54

eigenvalues um are doing copying so we

43:57

still have a lot of attention heads that are

43:58

doing copying
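
One way to make the "positive eigenvalues means copying" test concrete is the sketch below, with random placeholder weights standing in for a real checkpoint's W_U, W_O, W_V, W_E; the summary statistic (signed eigenvalue mass) is one plausible choice, not necessarily the exact one used here:

```python
import numpy as np

rng = np.random.default_rng(0)
d_vocab, d_model, d_head = 50, 16, 4

# Random placeholders; a real check would load these from the model.
W_E = rng.normal(size=(d_model, d_vocab))
W_U = rng.normal(size=(d_vocab, d_model))
W_V = rng.normal(size=(d_head, d_model))
W_O = rng.normal(size=(d_model, d_head))

# Full OV circuit acting on tokens: vocab -> vocab.
ov_circuit = W_U @ W_O @ W_V @ W_E

eig = np.linalg.eigvals(ov_circuit)
# Signed eigenvalue mass: near +1 if a token mostly raises its own logit (copying).
copying_score = eig.real.sum() / np.abs(eig).sum()
print(f"copying score: {copying_score:+.2f}")
```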

44:02

um

44:04

and yeah the qk circuit so one one thing

44:06

you could do is you could try to

44:07

understand things in terms of this more

44:08

complex qk equation you could also just

44:10

try to understand what the attention

44:11

patterns are doing empirically so let's

44:13

look at one of these copying ones

44:15

um i've given it the first paragraph of

44:17

harry potter and we can just look at

44:19

where it attends

44:23

and something really

44:24

interesting happens so almost all the

44:25

time

44:26

we just attend back to

44:28

the first token we have this this

44:30

special token at the beginning of the

44:31

sequence

44:32

and we usually think of that as just

44:33

being um a null attention operation it's

44:35

a way for it to not do anything in fact

44:37

if you if you look the value vector is

44:38

basically zero it's just not copying any

44:40

information from that

44:42

um

44:44

but when whenever we see repeated text

44:46

something interesting happens so when we

44:47

get to mr

44:49

it tries to attend and it's a little bit

44:51

weak then we get to d

44:54

and it attends to 'urs'

44:56

that's interesting

44:57

and then we get to 'urs'

45:00

and it attends to 'ley'

45:03

um and so it's not attending to

45:06

the same token it's attending to the

45:08

same token

45:10

shifted one forward

45:12

well that's really interesting and

45:14

there's actually a lot of attention heads

45:15

that are doing this so here we have one

45:17

where now we hit the 'pot' in potters and we

45:19

attend to 'ters' maybe that's the same

45:21

attention head i don't remember when i was

45:22

constructing this example

45:24

um it turns out this is a super common

45:25

thing so you you go and you you look for

45:28

the previous example you shift one

45:29

forward and you're like okay well last

45:30

time i saw this this is what happened

45:32

probably the same thing's gonna happen

45:36

um and we can we can go and look at the

45:39

effect that the attention head has on

45:41

the logits most of the time it's not

45:42

affecting things but in these cases it's

45:44

able to go and predict when it's doing

45:46

this thing of going and looking

45:47

forward to go and predict the next token

45:51

um so we call this an induction head an

45:52

induction head looks for the previous

45:54

copy looks forward and says ah probably

45:56

the same thing that happened last time

45:57

it's gonna happen you can think of this

45:59

as being a nearest neighbors thing it's like

46:00

an in-context nearest neighbors

46:02

algorithm it's going and searching

46:04

through your context finding similar

46:05

things and then predicting that's what's

46:07

gonna happen next
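
Stripped of the linear algebra, the behavioral rule is simple enough to write down directly; this toy sketch is just "find where the current token occurred before and predict whatever followed it", which is the in-context nearest-neighbor idea, not the model's actual computation:

```python
def induction_predict(tokens):
    """Toy 'in-context nearest neighbor': find the most recent earlier
    occurrence of the current (last) token and predict the token that
    followed it; return None if the token hasn't been seen before."""
    current = tokens[-1]
    for i in range(len(tokens) - 2, -1, -1):   # scan backwards over the prefix
        if tokens[i] == current:
            return tokens[i + 1]
    return None

# "Mr D urs ley ... Mr D" -> predicts "urs"
print(induction_predict(["Mr", "D", "urs", "ley", "was", "Mr", "D"]))
```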

46:10

um

46:11

the way that these actually work is uh i

46:14

mean there's actually two ways but in a

46:17

model that uses rotary attention or

46:18

something like this you only only have

46:19

one

46:20

and

46:22

you shift your key first you have

46:25

an earlier attention head that shifts your

46:26

key forward once you you like take the

46:28

value of the previous token and you

46:30

embed it in your present token

46:32

and then you have your query and your

46:33

key go and look

46:35

at

46:36

uh

46:37

yeah try to go and match so you look for

46:38

the same thing

46:40

um and then you go and you predict that

46:42

whatever you saw is going to be the next

46:44

token so that's the high-level algorithm
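
Plugging that previous-token head into the key side of the layer-two score sketched earlier gives, very roughly (dropping every other path through the residual stream), the matching term:

```latex
s_{ij} \;\approx\; \cdots \;+\;
  \big(W_E\, t_i\big)^{\top}\, W_{QK}^{\mathrm{ind}}\,
  \big(W_{OV}^{\mathrm{prev}}\, W_E\, t_{j-1}\big)
  \;+\; \cdots
```

That term is large when the current token t_i matches t_{j-1}, the token just before position j, and the induction head's OV circuit then copies t_j, i.e. whatever followed the earlier match.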

46:45

um sometimes you can do clever

46:48

things where actually it'll care about

46:49

multiple earlier tokens and it'll look

46:51

for like short phrases and so on so

46:52

induction heads can really vary in in

46:54

how much of the previous context

46:56

they care about or what aspects of the

46:57

previous context they care about but

46:59

this general trick of looking for the

47:00

same thing shift forward predict that is

47:03

what induction heads do

47:06

um lots of examples of this

47:08

and the cool thing is you can now

47:10

you can use the qk eigenvalues to

47:12

characterize this you can say well you

47:14

know we we're looking for the same thing

47:16

shifted by one but looking for the same

47:17

thing if you expand through the

47:18

attention heads in the right way that'll

47:19

work out

47:20

um and we're copying and so an induction

47:22

head is one which has both positive ov

47:25

eigenvalues and also positive qk

47:27

eigenvalues
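
Using the same signed-eigenvalue statistic as in the copying sketch above, flagging induction-head candidates might look like this; the circuits passed in are assumed to be the full token-space ov circuit and the qk circuit expanded through the previous-token head, and the 0.5 threshold is arbitrary:

```python
import numpy as np

def eig_positivity(circuit: np.ndarray) -> float:
    """Signed eigenvalue mass, as in the copying check above: near +1 means
    the circuit mostly maps each token back onto itself (copy / match)."""
    eig = np.linalg.eigvals(circuit)
    return float(eig.real.sum() / np.abs(eig).sum())

def looks_like_induction_head(ov_circuit: np.ndarray,
                              qk_circuit: np.ndarray,
                              threshold: float = 0.5) -> bool:
    # Positive OV eigenvalues: the head copies what it attends to.
    # Positive QK eigenvalues (through the shift): it attends to matches
    # of the current token, one position forward.
    return (eig_positivity(ov_circuit) > threshold
            and eig_positivity(qk_circuit) > threshold)
```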

47:32

um and so you can just put that on a

47:33

plot and you have your induction heads

47:36

in the corner see

47:37

here are your ov eigenvalues and your qk eigenvalues

47:40

i think actually ov is this axis qk is

47:42

this one, the axis doesn't matter um and in

47:44

the corner you have your eigen

47:46

values

47:47

or your um your induction heads

47:51

um yeah and so this seems to be uh well

47:54

okay we now have a natural hypothesis

47:55

the hypothesis is that the

47:57

phase change we're seeing the phase

47:59

change is the discovery of these

48:00

induction heads that would be um the

48:02

hypothesis uh and these are way more

48:05

effective than regular you know than

48:06

this first algorithm we had which was

48:08

just sort of blindly copy things

48:09

wherever it could be plausible now we

48:11

can go and like actually recognize

48:13

patterns and look at what happened and

48:14

predict that similar things are going to

48:15

happen again that's a way better

48:16

algorithm

48:19

um

48:21

yeah so there's other attention heads

48:22

that are doing more local things i'm

48:23

gonna go and skip over that and return

48:25

to our mystery because i am running out

48:26

of time i have five more minutes okay so

48:28

what what is going on with this in

48:30

context learning well now we have a

48:31

hypothesis let's check it um so we think

48:34

it might be induction heads

48:36

um

48:37

and there's a few reasons we believe

48:39

this so one thing is going to be that

48:40

inductive uh induction heads

48:43

well okay i'll just go over to the other

48:45

so one thing you can do is you can just

48:47

ablate the attention heads

48:49

and it turns out um you can color, here we

48:51

have attention heads colored by how much

48:53

they are an induction head
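
The talk doesn't give the metric behind that coloring; one rough, purely illustrative score is to run the model on a random token block repeated twice and measure how much a head attends from each token in the second copy to the token right after its match in the first copy:

```python
import numpy as np

def induction_score(attn_pattern: np.ndarray, seq_len: int) -> float:
    """attn_pattern: (2*seq_len, 2*seq_len) attention weights for one head on
    a random token block of length seq_len repeated twice.  For each position
    in the second copy, an induction head should put its attention on the
    token right after the matching position in the first copy."""
    hits = [attn_pattern[i, i - seq_len + 1] for i in range(seq_len, 2 * seq_len)]
    return float(np.mean(hits))
```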

48:55

and this is the start of the bump this

48:57

is the end of the bump here

48:59

and we can see, first of all, that

49:01

induction heads are forming like

49:02

previously we didn't have induction

49:04

heads here now they're just starting to

49:05

form here and then we have really

49:07

intense induction heads here and here

49:11

and the attention heads where if you

49:12

ablate them and you get

49:15

a, uh,

49:16

you get a hit — well, not to the

49:18

loss, but to this meta

49:20

learning score, or

49:22

in-context learning score, the

49:23

difference between the 500th token and

49:25

the 50th token
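
Concretely, that score could be computed from losses averaged by token position, something like the sketch below; the positions 50 and 500 follow the talk, and `per_token_loss` is a hypothetical array of dataset-averaged losses indexed by position:

```python
import numpy as np

def in_context_learning_score(per_token_loss: np.ndarray) -> float:
    """Loss at the 500th token position minus loss at the 50th, with
    per_token_loss averaged over many sequences.  More negative means the
    model gets more out of its context late in the sequence than early."""
    return float(per_token_loss[500] - per_token_loss[50])
```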

49:26

and that's all explained by induction

49:29

heads

49:30

now we actually have one induction head

49:32

that doesn't contribute to it actually

49:33

it does the opposite so that's kind of

49:34

interesting uh maybe it's doing

49:36

something shorter shorter distance um

49:38

and there's also this interesting thing

49:39

where like they all rush to be induction

49:41

heads and then they discover

49:43

only a few win out in the end so

49:44

there's some interesting dynamics going

49:46

on there but it really seems like in

49:47

these small models

49:49

all of in context learning is explained

49:51

by these induction heads

49:53

um okay

49:55

what about large models well in large

49:57

models it's going to be harder to go and

49:58

ask this but one thing you can do is you

49:59

can ask okay you know

50:01

we can look at our uh our induction or

50:04

our in-context learning score over time

50:06

we get this sharp phase change oh look

50:08

induction heads form at exactly the same

50:10

point in time

50:12

so that's only correlational evidence

50:14

but it's pretty suggestive correlational

50:16

evidence even especially given that we

50:17

have an obvious you know like the

50:19

obvious effect that induction heads

50:20

should have is this um i guess it

50:22

could be that there's other mechanisms

50:23

being discovered at the same time in

50:25

large models but it has to be in a very

50:26

small window

50:28

so

50:29

really suggests that the thing that's driving

50:31

that change in in-context learning is the induction heads

50:34

um okay so

50:35

uh

50:36

obviously induction heads can go and

50:38

copy text but a question you might ask

50:40

is you know can they can they do

50:42

translation like there's all these

50:43

amazing things that models can do that

50:45

it's not obvious you know in context

50:47

learning um or this sort of copying

50:49

mechanism could do so i just want to

50:51

very quickly

50:52

um

50:53

look at a few fun examples

50:56

so here we have

50:58

um

50:59

an attention pattern

51:01

oh i guess i need to open lexus scopes

51:08

hmm let me try doing that again

51:10

sorry i should have thought this through

51:12

a bit more before this talk

51:14

um

51:15

chris could you zoom in a little please

51:17

yeah yeah thank you

51:20

um

51:38

okay my french isn't that great

51:41

but

51:42

um my name is christopher i'm from

51:43

canada

51:45

um what we can do here is we can look at

51:47

where this attention head attends as we

51:49

go and we do this and

51:51

um it'll become especially clear on the

51:53

second sentence so here we're on the

51:55

period

51:56

and we attend to 'je'

51:58

now we're on

51:59

um and 'je' is 'i' in french okay now we're

52:02

on the 'i' and we attend to 'suis'

52:05

uh now we're on the 'am' and we attend

52:08

to 'du' which is 'from' and then on 'from' to

52:10

canada

52:12

and so we're doing a cross-lingual

52:14

induction head which we can use for

52:16

translation um and indeed if you look at

52:19

examples this is this is where it seems

52:21

to you know it seems to be a major

52:22

driving force in the model's ability to

52:24

go um and correctly do translation

52:28

another fun example is

52:30

um i think maybe maybe the most

52:33

impressive thing about in context

52:34

learning to me has been the model's

52:36

ability to go and learn arbitrary

52:37

functions like you can just show the

52:38

model a function it can start mimicking

52:40

that function well okay

52:43

yes yeah so do these induction heads

52:45

only do kind of a look ahead copy or

52:48

like can they also do some sort of like

52:50

a complex

52:52

uh structure recognition

52:54

yeah yeah so they can they can both use

52:56

a larger previous context and

52:59

they can copy more abstract things so

53:01

like the translation one is showing you

53:02

that they can copy rather than the

53:03

literal token a translated version so

53:05

it's what i call a soft induction head

53:07

um and yeah you can you can have them

53:10

copy similar words you can have them

53:11

look at longer contexts it can look for

53:13

more structural things um

53:15

the way that we usually characterize

53:16

them is is whether in in large models

53:19

just whether they empirically behave

53:20

like an induction head so the

53:22

boundary, the definition, gets a little bit

53:23

blurry when you try to encompass these

53:25

more abstract versions, it's sort of a blurry

53:26

boundary um but yeah there seem to be a

53:28

lot of attention heads that are doing

53:30

sort of more and more abstract versions

53:33

and yeah my my favorite version is this

53:35

one that i'm about to show you which is

53:37

um let's isolate a single one of

53:40

these which can do pattern recognition

53:42

so it can learn functions in the context

53:44

and learn how to do it so i've just made

53:46

up a nonsense function here um

53:49

we're going to encode one binary

53:50

variable with a choice of whether to do

53:53

a color or a month as the first word

53:56

then

53:57

we're gonna so we have green or june

53:58

here

53:59

um let's zoom in more

54:02

so we have color or month

54:05

and animal or fruit and then we have to

54:07

map it to either true or false

54:09

so that's our goal and it's going to be

54:10

an xor so we have a binary variable

54:12

represented in this way we do an xor

54:14

i'm

54:15

pretty confident this was never in the

54:17

training set because i just made it up

54:18

and it seems like a nonsense problem
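
To make the setup concrete, here is one way such prompts could be generated; the word lists, the prompt format, and the convention that "month followed by animal" maps to true are assumptions chosen to match the "april dog → true" example in the talk, not the exact demo:

```python
import random

COLORS, MONTHS = ["green", "red", "blue"], ["june", "april", "october"]
ANIMALS, FRUITS = ["dog", "cat", "horse"], ["apple", "mango", "pear"]

def make_example(rng: random.Random) -> str:
    use_month = rng.random() < 0.5          # binary variable 1: month vs color
    use_animal = rng.random() < 0.5         # binary variable 2: animal vs fruit
    first = rng.choice(MONTHS if use_month else COLORS)
    second = rng.choice(ANIMALS if use_animal else FRUITS)
    # XOR-style rule, up to relabeling: assume (month, animal) -> true.
    label = "true" if use_month == use_animal else "false"
    return f"{first} {second}: {label}"

rng = random.Random(0)
prompt = "\n".join(make_example(rng) for _ in range(8)) + "\napril dog:"
print(prompt)   # an in-context learner should continue with " true"
```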

54:20

um okay so then we can go and ask you

54:22

know can the model go and do that well

54:24

it can and it uses induction heads to do

54:26

it and what we can do is we can look at

54:28

the so we look at a colon where it's

54:30

going to go and try and predict the next

54:31

word

54:32

and for instance here

54:34

um we have april dog so it's a month and

54:37

then an animal and it should be true

54:39

and what it does is it looks for a

54:40

previous case where there was

54:42

an animal

54:43

a month and then an animal especially

54:45

one where the month was the same and

54:47

goes and looks and says that it's true

54:49

and so the model can go and learn learn

54:51

a function a completely arbitrary

54:53

function

54:54

by going and doing this kind of pattern

54:55

recognition

54:57

induction head

54:59

and this to me made it a lot more

55:00

plausible that these models actually

55:04

can do

55:06

can do in context learning like the

55:07

generality of all these amazing things

55:09

we see these large language models do um

55:11

uh can be explained by induction heads we

55:14

don't know that it could be that there's

55:15

other things going on um it's very

55:16

possible that there's lots of other

55:17

things going on um but it seems seems a

55:20

lot more plausible to me than it did

55:21

when when we started

55:24

i'm conscious that i am actually over

55:26

time i'm just going to quickly go through

55:27

these last few slides yeah so i think

55:29

thinking that this is like an in context

55:30

nearest neighbors i think is a really

55:32

useful way to think about this um other

55:34

things could absolutely be contributing

55:36

um this might explain why uh

55:38

transformers do in-context learning uh

55:41

over long contexts better than uh lstms

55:44

an lstm can't do this because it's

55:46

not linear in the amount of compute it

55:47

needs it's like quadratic or n log n if

55:49

it was really clever

55:50

um so for an lstm it's

55:52

impossible to do this transformers um do

55:54

do this and actually they diverge at the

55:56

same point but if you if you look well

55:59

i can go into this in more detail after

56:01

if you want um

56:03

there's a really nice paper by marcus

56:04

hutter that tries to predict and

56:06

explain why we observe scaling laws in

56:08

models it's worth noting that the

56:09

arguments in this paper go exactly

56:11

through to uh this example this theory

56:14

in fact they sort of work better for the

56:16

case of thinking about this in

56:18

context learning with with essentially a

56:19

nearest neighbors algorithm um than

56:21

they do in in the regular case so

56:24

um yeah uh i'm happy to answer questions

56:26

i can go into as much detail as people

56:28

want about any of this and i can also if

56:31

you send me an email send you more

56:32

information about all this um and

56:34

uh yeah and you know again this this

56:36

work is not yet published and you don't

56:38

have to keep it secret but um you know

56:40

just if you could be thoughtful about

56:41

the fact that um it's unpublished work

56:43

and probably a month or two away from

56:44

coming out um i'd be really grateful for

56:46

that uh thank you so much for your time

56:49

yeah thanks a lot chris this was a great

56:50

talk

56:53

um so i'll just open up like some

56:54

general questions and then we can do

56:56

like a round of questions from the

56:57

students

56:58

so i was very excited to know like so

57:00

what is the

57:01

like the line of work that you're

57:02

currently working on is it like

57:03

extending this uh so what do you think

57:05

is like the next things you try to do to

57:08

make it more interpretable what are the

57:09

next yeah

57:10

i mean i want to just reverse engineer

57:12

language models i want to figure out the

57:14

entirety of what's going on in these

57:15

language models

57:16

um

57:17

and uh

57:20

you know like

57:21

one thing that we totally don't

57:22

understand is

57:24

mlp layers um i mean we understand

57:27

something about them um but we we don't

57:29

really understand mlp layers very well

57:31

uh there's a lot of stuff going on in

57:32

large models that we don't understand i

57:34

want to know how models do arithmetic um

57:36

i want to know um another thing that i'm

57:38

very interested in is what's going on

57:39

when you have multiple speakers the

57:41

model can clearly represent like it has

57:42

it has like a basic theory of mind

57:44

multiple speakers in a dialogue i want to

57:45

understand what's going on with that um

57:47

but honestly there's just so much we

57:49

don't understand um it's really it's

57:51

sort of hard to answer the question

57:52

because there's just so much to to

57:54

figure out um and we have a lot of

57:56

different threads of research and doing

57:58

this but um yeah and

58:00

uh the interpretability team at anthropic is

58:02

just sort of

58:03

has a bunch of threads trying to go and

58:05

figure out what's going on inside these

58:06

models and sort of a similar flavor to

58:08

this of just trying to figure out how do

58:10

the parameters actually encode

58:11

algorithms and can we reverse engineer

58:13

those into into meaningful computer

58:16

programs that we can we can understand

58:18

well

58:19

uh another question that is like so you

58:20

were talking about like how

58:22

like the transformers are trying to do

58:23

meta-learning inherently so it's like

58:26

and you already spent a lot of time

58:27

talking about like uh the induction heads

58:29

and like that was like interesting but

58:30

like can you formalize the sort of meta

58:32

learning algorithm they might be

58:33

learning is it like possible to say like

58:35

oh maybe this is a sort of like uh like

58:37

internal algorithm that's going that's

58:38

making them like good meta-learners or

58:40

something like that

58:41

i don't know i mean i think i think so i

58:43

think that there's roughly two

58:44

algorithms one is this algorithm we saw

58:46

in the one layer model and we see it in

58:48

other models too especially early on

58:49

which is just you know try to copy you

58:51

know you saw a word probably a similar

58:53

word is going to happen uh later um look

58:55

for places that it might fit in and

58:57

increase the probability so that's

58:58

that's one thing that we see and the

59:00

other thing we see is uh induction heads

59:02

which you can just summarize as as in

59:04

context nearest neighbors basically um

59:06

and it seems you know possibly there's

59:08

other things but it seems like those two

59:10

algorithms um and you know the specific

59:12

instantiations that we are looking at uh

59:14

seem to be what's driving in context

59:16

learning that would be my present theory

59:18

yeah sounds very interesting

59:20

um yeah okay um so let's open like a

59:23

round of first two questions so yeah

59:25

feel free to go ahead

59:26

for those questions
