WEBVTT

00:00:05.990 --> 00:00:06.629
Hi, everyone.

00:00:06.629 --> 00:00:09.327
Welcome to CS 25
Transformers United V2.

00:00:09.327 --> 00:00:11.119
This was a course that
was held at Stanford

00:00:11.119 --> 00:00:13.264
in the winter of 2023.

00:00:13.265 --> 00:00:14.839
This course is not
about robots that

00:00:14.839 --> 00:00:17.324
can transform into cars as
this picture might suggest.

00:00:17.324 --> 00:00:18.949
Rather, it's about
deep learning models

00:00:18.949 --> 00:00:21.199
that have taken
the world by storm

00:00:21.199 --> 00:00:23.439
and have revolutionized
the field of AI and others.

00:00:23.440 --> 00:00:25.190
Starting from natural
language processing,

00:00:25.190 --> 00:00:27.560
transformers have
been applied all over,

00:00:27.559 --> 00:00:30.320
computer vision, reinforcement
learning, biology, robotics,

00:00:30.320 --> 00:00:31.684
et cetera.

00:00:31.684 --> 00:00:34.100
We have an exciting set
of videos lined up for you

00:00:34.100 --> 00:00:37.719
with some truly fascinating
speakers presenting

00:00:37.719 --> 00:00:39.094
how they're applying
transformers

00:00:39.094 --> 00:00:41.494
to the research in
different fields and areas.

00:00:44.070 --> 00:00:47.700
We hope you'll enjoy and
learn from these videos.

00:00:47.700 --> 00:00:52.130
So without any further
ado, let's get started.

00:00:52.130 --> 00:00:54.760
This is a purely
introductory lecture.

00:00:54.759 --> 00:00:58.750
And we'll go into the building
blocks of transformers.

00:00:58.750 --> 00:01:03.530
So first, let's start with
introducing the instructors.

00:01:03.530 --> 00:01:06.109
So for me, I'm currently on a
temporary deferral from the PhD

00:01:06.109 --> 00:01:09.200
program, and I'm leading AI at a
robotics startup, Collaborative

00:01:09.200 --> 00:01:13.579
Robotics, that is working on
some general purpose robots,

00:01:13.579 --> 00:01:14.929
somewhat like [INAUDIBLE].

00:01:14.930 --> 00:01:18.560
And I'm very passionate about
robotics and building FSG

00:01:18.560 --> 00:01:19.513
learning algorithms.

00:01:19.513 --> 00:01:21.680
My research interests are
in reinforcement learning,

00:01:21.680 --> 00:01:23.930
computer vision, and
remodeling, and I

00:01:23.930 --> 00:01:25.820
have a bunch of
publications in robotics,

00:01:25.819 --> 00:01:28.357
autonomous driving,
and other areas.

00:01:28.358 --> 00:01:29.525
My undergrad was at Cornell.

00:01:29.525 --> 00:01:33.850
If someone is from Cornell,
so nice to [INAUDIBLE].

00:01:33.849 --> 00:01:37.209
So I'm Stephen, currently
a first-year CS PhD here.

00:01:37.209 --> 00:01:40.609
Previously did my master's at
CMU and undergrad at Waterloo.

00:01:40.609 --> 00:01:43.540
I'm mainly into NLP research,
anything involving language

00:01:43.540 --> 00:01:45.880
and text, but more
recently, I've

00:01:45.879 --> 00:01:48.789
been getting more into computer
vision as well as [INAUDIBLE]

00:01:48.790 --> 00:01:51.520
And just some stuff I do
for fun, a lot of music

00:01:51.519 --> 00:01:52.899
stuff, mainly piano.

00:01:52.900 --> 00:01:55.600
Some self-promo of what I post
a lot on my Insta, YouTube,

00:01:55.599 --> 00:01:58.780
and TikTok, so if you
guys want to check it out.

00:01:58.780 --> 00:02:01.719
My friends and I are also
starting a Stanford piano club,

00:02:01.719 --> 00:02:04.539
so if anybody's interested,
feel free to email

00:02:04.540 --> 00:02:07.060
or DM me for details.

00:02:07.060 --> 00:02:11.530
Other than that, martial arts,
bodybuilding, and huge fan

00:02:11.530 --> 00:02:14.890
of K-dramas and anime,
and an occasional gamer.

00:02:14.889 --> 00:02:18.229
[LAUGHS]

00:02:18.729 --> 00:02:19.269
OK, cool.

00:02:19.270 --> 00:02:20.710
Yeah, so my name is Rylan.

00:02:20.710 --> 00:02:21.820
Instead of talking
about myself, I just

00:02:21.819 --> 00:02:23.444
want to very briefly
say that I'm super

00:02:23.444 --> 00:02:24.789
excited to take this class.

00:02:24.789 --> 00:02:26.409
I took it the last time--
sorry-- to teach this.

00:02:26.409 --> 00:02:26.740
Excuse me.

00:02:26.740 --> 00:02:28.360
I took it the last
time it was offered.

00:02:28.360 --> 00:02:30.280
I had a bunch of fun.

00:02:30.280 --> 00:02:32.650
I thought we brought in a
really great group of speakers

00:02:32.650 --> 00:02:33.150
last time.

00:02:33.150 --> 00:02:35.287
I'm super excited
for this offering.

00:02:35.287 --> 00:02:37.120
And yeah, I'm thankful
that you're all here,

00:02:37.120 --> 00:02:39.020
and I'm looking forward to a
really fun quarter together.

00:02:39.020 --> 00:02:39.530
Thank you.

00:02:39.530 --> 00:02:42.129
Yeah, so fun fact, Rylan was
the most outspoken student

00:02:42.129 --> 00:02:43.103
last year.

00:02:43.103 --> 00:02:45.520
And so if someone wants to
become an instructor next year,

00:02:45.520 --> 00:02:46.762
you know what to do.

00:02:46.762 --> 00:02:49.954
[LAUGHTER]

00:02:50.870 --> 00:02:53.800
OK, cool.

00:02:53.800 --> 00:02:54.300
Let's see.

00:02:54.300 --> 00:02:56.510
OK, I think we
have a few minutes.

00:02:56.509 --> 00:02:59.459
So what we hope you will learn
in this class is, first of all,

00:02:59.460 --> 00:03:02.585
how transformers
work, how they

00:03:02.585 --> 00:03:04.103
are being applied
beyond just NLP,

00:03:04.103 --> 00:03:06.020
and nowadays, like they
are pretty [INAUDIBLE]

00:03:06.020 --> 00:03:10.290
them everywhere in
AI and machine learning.

00:03:10.289 --> 00:03:12.539
And what are some new and
interesting directions

00:03:12.539 --> 00:03:14.359
of research in these topics.

00:03:17.759 --> 00:03:19.724
Cool, so this class is
just introductory.

00:03:19.724 --> 00:03:22.215
So we're just talking about
the basics of transformers,

00:03:22.215 --> 00:03:24.930
introducing them, talking about
the self-attention mechanism

00:03:24.930 --> 00:03:26.580
on which they're founded.

00:03:26.580 --> 00:03:30.870
And we'll do a deep dive
more on models like BERT

00:03:30.870 --> 00:03:32.250
and GPT, stuff like that.

00:03:32.250 --> 00:03:35.620
So with that, happy
to get started.

00:03:35.620 --> 00:03:38.280
OK, so let me start with
presenting the attention

00:03:38.280 --> 00:03:40.539
timeline.

00:03:40.539 --> 00:03:43.239
Attention all started
with this one paper.

00:03:43.240 --> 00:03:46.270
Attention Is All You Need by
Vaswani et al. in 2017.

00:03:46.270 --> 00:03:49.450
That was the beginning
of transformers.

00:03:49.449 --> 00:03:51.489
Before that, we had
the prehistoric era,

00:03:51.490 --> 00:03:55.840
where we had models
like RNNs, LSTMs,

00:03:55.840 --> 00:03:57.909
and simple attention
mechanisms that didn't work

00:03:57.909 --> 00:03:59.949
or [INAUDIBLE].

00:03:59.949 --> 00:04:02.994
Starting 2017, we saw this
explosion of transformers

00:04:02.995 --> 00:04:07.180
into NLP, where people started
using it for everything.

00:04:07.180 --> 00:04:08.680
I even heard this
quote from Google.

00:04:08.680 --> 00:04:10.597
It's like our performance
increased every time

00:04:10.597 --> 00:04:11.770
we [INAUDIBLE]

00:04:11.770 --> 00:04:13.183
[CHUCKLES]

00:04:15.069 --> 00:04:17.098
For the [INAUDIBLE]
after 2018 to 2020,

00:04:17.098 --> 00:04:18.639
we saw this explosion
of transformers

00:04:18.639 --> 00:04:23.500
into other fields like vision,
a bunch of other stuff,

00:04:23.500 --> 00:04:25.990
and like biology as a whole.

00:04:25.990 --> 00:04:28.329
And last year,
2021 was the start

00:04:28.329 --> 00:04:31.224
of the generative era, where we
got a lot of generative modeling,

00:04:31.225 --> 00:04:35.350
starting with models like
Codex, GPT, DALL-E,

00:04:35.350 --> 00:04:37.360
Stable Diffusion,
so a lot of things

00:04:37.360 --> 00:04:40.330
happening in generative modeling.

00:04:40.329 --> 00:04:44.229
And we started scaling up in AI.

00:04:44.230 --> 00:04:45.490
And now, the present.

00:04:45.490 --> 00:04:49.269
So this is 2022 and
the start of '23.

00:04:49.269 --> 00:04:53.259
And now we have models
like ChatGPT, Whisper,

00:04:53.259 --> 00:04:54.550
a bunch of others.

00:04:54.550 --> 00:04:57.250
And we're scaling onwards
without splitting up,

00:04:57.250 --> 00:04:58.810
so that's great.

00:04:58.810 --> 00:05:01.649
So that's the future.

00:05:01.649 --> 00:05:06.939
So going more into this,
so once there were RNNs.

00:05:06.939 --> 00:05:10.829
So we had Seq2Seq
models, LSTMs, GRUs.

00:05:10.829 --> 00:05:13.839
What worked there was that they
were good at encoding history,

00:05:13.839 --> 00:05:17.064
but what did not work was they
didn't encode long sequences

00:05:17.064 --> 00:05:21.649
and they were very bad
at encoding context.

00:05:21.649 --> 00:05:24.569
So consider this example.

00:05:24.569 --> 00:05:27.529
Consider trying to predict
the last word in the text,

00:05:27.529 --> 00:05:29.329
"I grew up in France,
dot, dot, dot.

00:05:29.329 --> 00:05:31.250
I speak fluent ___."

00:05:31.250 --> 00:05:33.740
Here, you need to understand
the context for it

00:05:33.740 --> 00:05:36.470
to predict French, and
the attention mechanism

00:05:36.470 --> 00:05:39.425
is very good at that, whereas
if you're just using LSTMs,

00:05:39.425 --> 00:05:42.350
it doesn't work that well.

00:05:42.350 --> 00:05:46.400
Another thing transformers
are good at is

00:05:46.399 --> 00:05:50.149
content-based
and context-based prediction,

00:05:50.149 --> 00:05:52.729
like finding attention maps.

00:05:52.730 --> 00:05:56.450
If I have a word
like "it,"

00:05:56.449 --> 00:05:57.979
which noun does it correlate to?

00:05:57.980 --> 00:06:01.759
And we can give a
probability of attention

00:06:01.759 --> 00:06:05.240
to each of the
possible activations.

00:06:05.240 --> 00:06:10.360
And this works better
than existing mechanisms.

00:06:10.360 --> 00:06:16.465
OK, so where we were in 2021,
we were on the verge of takeoff.

00:06:16.464 --> 00:06:18.839
We were starting to realize
the potential of transformers

00:06:18.839 --> 00:06:20.879
in different fields.

00:06:20.879 --> 00:06:23.115
We solved a lot of
long sequence problems

00:06:23.115 --> 00:06:26.340
like protein folding,
AlphaFold, offline RL.

00:06:28.860 --> 00:06:31.512
We started to see few-shot and
zero-shot generalization.

00:06:31.512 --> 00:06:34.425
We saw multimodal
tasks and applications

00:06:34.425 --> 00:06:36.300
like generating
images from language.

00:06:36.300 --> 00:06:40.997
So that was DALL-E. And
it feels like [INAUDIBLE].

00:06:43.865 --> 00:06:45.639
And this was also a
talk on transformers

00:06:45.639 --> 00:06:48.610
that you can watch on YouTube.

00:06:48.610 --> 00:06:51.129
Yeah, cool.

00:06:51.129 --> 00:06:55.269
And this is where we were
going from 2021 to 2022,

00:06:55.269 --> 00:06:58.814
which is we have gone from
the verge of [INAUDIBLE]

00:06:58.814 --> 00:07:00.564
And now, we are seeing
unique applications

00:07:00.564 --> 00:07:03.745
in audio generation,
art, music, storytelling.

00:07:03.745 --> 00:07:05.620
We are starting to see
these new capabilities

00:07:05.620 --> 00:07:08.379
like commonsense,
logical reasoning,

00:07:08.379 --> 00:07:09.879
mathematical reasoning.

00:07:09.879 --> 00:07:12.819
We are also able to now
get human alignment

00:07:12.819 --> 00:07:13.949
and interaction.

00:07:13.949 --> 00:07:15.699
They're able to use
reinforcement learning

00:07:15.699 --> 00:07:16.689
from human feedback.

00:07:16.689 --> 00:07:19.457
That's how ChatGPT is trained
to perform really well.

00:07:19.458 --> 00:07:21.250
We have a lot of
mechanisms for controlling

00:07:21.250 --> 00:07:24.370
toxicity, bias, and ethics now.

00:07:24.370 --> 00:07:26.110
And there are
also a lot

00:07:26.110 --> 00:07:30.530
of developments in other
areas like diffusion models.

00:07:30.529 --> 00:07:33.319
Cool.

00:07:33.319 --> 00:07:35.611
So the future is a
spaceship, and we are all

00:07:35.612 --> 00:07:36.320
excited about it.

00:07:39.401 --> 00:07:40.985
And there are a lot
more applications

00:07:40.985 --> 00:07:44.750
that we can enable,
and it'll be great

00:07:44.750 --> 00:07:47.689
if you can see
transformers also up there.

00:07:47.689 --> 00:07:49.939
One big example is video
understanding and generation.

00:07:49.939 --> 00:07:51.981
That is something that
everyone is interested in,

00:07:51.982 --> 00:07:53.900
and I'm hoping we'll
see a lot of models

00:07:53.899 --> 00:07:59.839
in this area this year,
and also in finance and business.

00:07:59.839 --> 00:08:02.750
I'll be very excited to
see GPT author a novel,

00:08:02.750 --> 00:08:04.970
but we need to solve very
long sequence modeling.

00:08:04.970 --> 00:08:07.700
And most transformer
models are still

00:08:07.699 --> 00:08:09.925
limited to 4,000 tokens
or something like that.

00:08:09.925 --> 00:08:13.879
So we need to make them
generalize much

00:08:13.879 --> 00:08:17.255
better on long sequences.

00:08:17.255 --> 00:08:19.399
We also want to have
generalized agents

00:08:19.399 --> 00:08:27.879
that can do a lot of multitask,
multi-input predictions

00:08:27.879 --> 00:08:28.750
like Gato.

00:08:28.750 --> 00:08:31.660
And so I think we will
see more of that, too.

00:08:31.660 --> 00:08:37.240
And finally, we also want
domain-specific models.

00:08:37.240 --> 00:08:39.490
So you might want
a GPT model that's, let's

00:08:39.490 --> 00:08:41.230
say, specialized for your health.

00:08:41.230 --> 00:08:43.129
So that could be like
a DoctorGPT model.

00:08:43.129 --> 00:08:45.100
You might have a
LawyerGPT model that's

00:08:45.100 --> 00:08:46.279
trained on only law data.

00:08:46.279 --> 00:08:49.209
So currently, we have GPT models
that are trained on everything.

00:08:49.210 --> 00:08:51.730
But we might start to see
more niche models that

00:08:51.730 --> 00:08:53.050
are good at one task.

00:08:53.049 --> 00:08:55.000
And we could have a
mixture of experts,

00:08:55.000 --> 00:08:57.190
so it's like, you
can think this is a--

00:08:57.190 --> 00:08:58.760
how you'd normally
consult an expert,

00:08:58.759 --> 00:09:00.220
you'll have expert AI models.

00:09:00.220 --> 00:09:02.887
And you can go to a different AI
model for your different needs.

00:09:05.049 --> 00:09:07.269
There are still a lot
of missing ingredients

00:09:07.269 --> 00:09:10.105
to make this all successful.

00:09:10.105 --> 00:09:12.414
The first of these
is external memory.

00:09:12.414 --> 00:09:15.144
We are already starting to
see this with models

00:09:15.144 --> 00:09:18.519
like ChatGPT, where the
interactions are short-lived.

00:09:18.519 --> 00:09:20.710
There's no long-term
memory, and they

00:09:20.710 --> 00:09:23.410
don't have ability
to remember or store

00:09:23.409 --> 00:09:25.969
conversations for long-term.

00:09:25.970 --> 00:09:29.980
And this is something
we want to fix.

00:09:29.980 --> 00:09:32.779
Second is reducing the
computational complexity.

00:09:32.779 --> 00:09:36.159
So the attention mechanism is
quadratic in the sequence

00:09:36.159 --> 00:09:37.689
length, which is slow.

00:09:37.690 --> 00:09:40.450
And we want to reduce
it and make it faster.

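NOTE
A minimal NumPy sketch, not taken from the lecture, of why attention is quadratic in sequence length: single-head scaled dot-product self-attention builds an n-by-n score matrix, so doubling n quadruples the number of scores. The projection matrices Wq, Wk, Wv and the function name are illustrative assumptions.
```python
import numpy as np
def self_attention(x, Wq, Wk, Wv):
    """Single-head scaled dot-product self-attention over x of shape (n, d)."""
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    scores = q @ k.T / np.sqrt(k.shape[-1])         # (n, n): every token scores every token
    scores -= scores.max(axis=-1, keepdims=True)    # subtract row max for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # row-wise softmax
    return weights @ v, scores.shape
rng = np.random.default_rng(0)
d = 16
Wq, Wk, Wv = (rng.normal(size=(d, d)) for _ in range(3))
for n in (512, 1024):
    _, shape = self_attention(rng.normal(size=(n, d)), Wq, Wk, Wv)
    print(n, shape, shape[0] * shape[1])  # 512 gives 262144 scores, 1024 gives 1048576
```
Doubling the length from 512 to 1,024 tokens quadruples the score matrix, which is the cost that work on faster attention tries to reduce.
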
00:09:42.855 --> 00:09:44.230
Another thing we
want to do is we

00:09:44.230 --> 00:09:46.355
want to enhance the
controllability of these models

00:09:46.355 --> 00:09:48.759
since a lot of these
models can be stochastic.

00:09:48.759 --> 00:09:51.009
And we want to be able to
control what sort of outputs

00:09:51.009 --> 00:09:52.307
we get from them.

00:09:52.307 --> 00:09:54.100
And you might have
experienced with ChatGPT that

00:09:54.100 --> 00:09:56.913
if you just refresh, you get
different output each time.

00:09:56.913 --> 00:09:59.080
But you might want to have
a mechanism that controls

00:09:59.080 --> 00:10:01.180
what sort of things you get.

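NOTE
One common knob for this, sketched under assumptions: sample the next token from a softmax over logits divided by a temperature. Lower temperatures concentrate probability on the top token, and a fixed random seed makes sampling repeatable. The logits and the function name are made up for illustration; this is not a description of how ChatGPT is configured.
```python
import numpy as np
def sample_next_token(logits, temperature, rng):
    """Sample a token index from softmax(logits / temperature)."""
    scaled = np.asarray(logits, dtype=float) / temperature
    probs = np.exp(scaled - scaled.max())  # stable softmax numerator
    probs /= probs.sum()
    return int(rng.choice(len(probs), p=probs)), probs
logits = [2.0, 1.0, 0.5, -1.0]  # hypothetical scores for a 4-token vocabulary
for t in (1.0, 0.1):
    _, probs = sample_next_token(logits, t, np.random.default_rng(42))
    print(t, probs.round(3))
```
At temperature 0.1 nearly all the probability sits on token 0, so repeated samples almost always agree; at 1.0 they vary, which is the "refresh and get a different output" behavior.
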
00:10:01.179 --> 00:10:04.239
And finally, we want to align
our state-of-the-art language

00:10:04.240 --> 00:10:06.200
models with how the
human brain works.

00:10:06.200 --> 00:10:09.280
And we are seeing some
research on this, but we still

00:10:09.279 --> 00:10:12.009
need more research on seeing
how we can make them more informed.

00:10:12.009 --> 00:10:14.460
Thank you.

00:10:14.460 --> 00:10:16.820
Great, hi.

00:10:16.820 --> 00:10:18.270
Yes, I'm excited to be here.

00:10:18.269 --> 00:10:21.079
I live very nearby, so I got
the invite to come to class.

00:10:21.080 --> 00:10:23.500
And I was like, OK,
I'll just walk over.

00:10:23.500 --> 00:10:25.375
But then I spent like
10 hours on the slides,

00:10:25.375 --> 00:10:28.250
so it wasn't as simple.

00:10:28.250 --> 00:10:30.710
So yeah, I'm going to
talk about transformers.

00:10:30.710 --> 00:10:32.620
I'm going to skip the
first two over there.

00:10:32.620 --> 00:10:34.139
I'm not going to
talk about those.

00:10:34.139 --> 00:10:36.389
We'll talk about that one
just to simplify the lecture

00:10:36.389 --> 00:10:39.379
since we don't have time.

00:10:39.379 --> 00:10:41.600
OK, so I wanted to provide
a little bit of context

00:10:41.600 --> 00:10:44.336
on why this transformers
class even exists.

00:10:44.336 --> 00:10:45.919
So a little bit of
historical context.

00:10:45.919 --> 00:10:47.569
I feel like Bilbo over there.

00:10:47.570 --> 00:10:50.712
Like, I'm telling
you guys about this.

00:10:50.711 --> 00:10:52.669
I don't know if you guys
saw Lord of the Rings.

00:10:52.669 --> 00:10:56.860
And basically, I joined AI in
roughly 2012, the full course,

00:10:56.860 --> 00:10:58.009
so maybe a decade ago.

00:10:58.009 --> 00:10:59.509
And back then, you
wouldn't even say

00:10:59.509 --> 00:11:00.809
that you joined AI by the way.

00:11:00.809 --> 00:11:02.449
That was like a dirty word.

00:11:02.450 --> 00:11:04.535
Now, it's OK to talk
about, but back then, it

00:11:04.534 --> 00:11:05.659
was not even deep learning.

00:11:05.659 --> 00:11:06.500
It was machine learning.

00:11:06.500 --> 00:11:08.625
That was the term we would
use if you were serious.

00:11:08.625 --> 00:11:11.960
But now, now, AI is
OK to use, I think.

00:11:11.960 --> 00:11:13.437
So basically, do
you even realize

00:11:13.437 --> 00:11:15.019
how lucky you are
potentially entering

00:11:15.019 --> 00:11:17.419
this area in roughly 2023?

00:11:17.419 --> 00:11:20.269
So back then, in 2011 or so
when I was working specifically

00:11:20.269 --> 00:11:25.960
on computer vision, your
pipelines looked like this.

00:11:25.960 --> 00:11:28.350
So if you wanted to
classify some images,

00:11:28.350 --> 00:11:30.850
you would go to a paper, and I
think this is representative.

00:11:30.850 --> 00:11:32.932
You would have three pages
in the paper describing

00:11:32.932 --> 00:11:34.986
all kinds of a zoo,
of kitchen sink,

00:11:34.986 --> 00:11:36.819
of different kinds of
features, descriptors.

00:11:36.820 --> 00:11:38.853
And you would go
to a poster session

00:11:38.852 --> 00:11:40.269
at a computer
vision conference,

00:11:40.269 --> 00:11:41.980
and everyone would have their
favorite feature descriptor

00:11:41.980 --> 00:11:42.610
that they're proposing.

00:11:42.610 --> 00:11:44.200
And it's totally
ridiculous, and you

00:11:44.200 --> 00:11:45.550
would take notes on which
one you should incorporate

00:11:45.549 --> 00:11:48.339
into your pipeline because
you would extract all of them,

00:11:48.340 --> 00:11:49.882
and then you would
put an SVM on top.

00:11:49.881 --> 00:11:51.048
So that's what you would do.

00:11:51.048 --> 00:11:52.000
So there's two pages.

00:11:52.000 --> 00:11:54.082
Make sure you get your
sparse SIFT histograms,

00:11:54.082 --> 00:11:56.110
your SSIMs, your color
histograms, textons,

00:11:56.110 --> 00:11:57.340
tiny images.

00:11:57.340 --> 00:11:59.649
And don't forget the
geometry specific histograms.

00:11:59.649 --> 00:12:02.225
All of them have basically
complicated code by themselves.

00:12:02.225 --> 00:12:04.600
So you're collecting code from
everywhere and running it,

00:12:04.600 --> 00:12:06.430
and it was a total nightmare.

00:12:06.429 --> 00:12:10.989
So on top of that,
it also didn't work.

00:12:10.990 --> 00:12:11.570
[LAUGHTER]

00:12:11.570 --> 00:12:14.440
So this, I think,
represents the predictions

00:12:14.440 --> 00:12:15.305
from that time.

00:12:15.304 --> 00:12:17.679
You would just get predictions
like this once in a while,

00:12:17.679 --> 00:12:19.329
and you'd be like, you
just shrug your shoulders

00:12:19.330 --> 00:12:20.955
like that just happens
once in a while.

00:12:20.955 --> 00:12:23.680
Today, you would be
looking for a bug.

00:12:23.679 --> 00:12:30.639
And worse than that,
every single chunk of AI

00:12:30.639 --> 00:12:32.866
had their own completely
separate vocabulary

00:12:32.866 --> 00:12:33.699
that they work with.

00:12:33.700 --> 00:12:36.810
So if you go to NLP
papers, those papers

00:12:36.809 --> 00:12:38.059
would be completely different.

00:12:38.059 --> 00:12:40.101
So you're reading the NLP
paper, and you're like,

00:12:40.101 --> 00:12:42.490
what is this part
of speech tagging,

00:12:42.490 --> 00:12:44.605
morphological analysis,
syntactic parsing,

00:12:44.605 --> 00:12:46.029
co-reference resolution?

00:12:46.029 --> 00:12:48.189
What is MPBTKJ?

00:12:48.190 --> 00:12:49.190
And you're confused.

00:12:49.190 --> 00:12:51.430
So the vocabulary and everything
was completely different.

00:12:51.429 --> 00:12:52.971
And you couldn't
read papers, I would

00:12:52.971 --> 00:12:55.100
say, across different areas.

00:12:55.100 --> 00:12:56.590
So now, that
changed a little bit

00:12:56.590 --> 00:13:02.379
starting in 2012 when Alex Krizhevsky
and colleagues basically

00:13:02.379 --> 00:13:05.439
demonstrated that if you
scale a large neural network

00:13:05.440 --> 00:13:08.460
on a large data set, you can
get very strong performance.

00:13:08.460 --> 00:13:10.960
And so up till then, there was
a lot of focus on algorithms.

00:13:10.960 --> 00:13:13.330
But this showed that actually
neural nets scale very well.

00:13:13.330 --> 00:13:15.160
So you need to now worry
about compute and data,

00:13:15.159 --> 00:13:16.159
and you can scale it up.

00:13:16.159 --> 00:13:17.329
It works pretty well.

00:13:17.330 --> 00:13:19.509
And then that recipe
actually got copy-pasted

00:13:19.509 --> 00:13:21.519
across many areas of AI.

00:13:21.519 --> 00:13:23.740
So we start to see neural
networks pop up everywhere

00:13:23.740 --> 00:13:25.768
since 2012.

00:13:25.768 --> 00:13:28.060
So we saw them in computer
vision, and NLP, and speech,

00:13:28.059 --> 00:13:30.339
and translation, and RL, and so on.

00:13:30.340 --> 00:13:32.649
So everyone started to use
the same kind of modeling

00:13:32.649 --> 00:13:33.985
toolkit, modeling framework.

00:13:33.985 --> 00:13:36.610
And now when you go to NLP, and
you start reading papers there,

00:13:36.610 --> 00:13:38.710
in machine translation,
for example,

00:13:38.710 --> 00:13:40.210
this is a sequence
to sequence paper

00:13:40.210 --> 00:13:41.923
which we'll come
back to in a bit.

00:13:41.923 --> 00:13:44.090
You start to read those
papers, and you're like, OK,

00:13:44.090 --> 00:13:45.340
I can recognize these words.

00:13:45.340 --> 00:13:46.420
Like there's a neural network.

00:13:46.419 --> 00:13:47.419
There's some parameters.

00:13:47.419 --> 00:13:50.064
There's an optimizer, and
it starts to read things

00:13:50.065 --> 00:13:50.950
that you know of.

00:13:50.950 --> 00:13:54.205
So that decreased tremendously
the barrier to entry

00:13:54.205 --> 00:13:56.490
across the different areas.

00:13:56.490 --> 00:13:57.970
And then, I think,
the big deal is

00:13:57.970 --> 00:14:00.317
that when the transformer
came out in 2017,

00:14:00.317 --> 00:14:02.860
it's not even that just the tool
kits and the neural networks

00:14:02.860 --> 00:14:05.529
were similar-- it's that
literally the architectures

00:14:05.529 --> 00:14:07.480
converged to like one
architecture that you

00:14:07.480 --> 00:14:10.180
copy paste across
everything seemingly.

00:14:10.179 --> 00:14:12.724
So this was kind of an
unassuming machine translation

00:14:12.725 --> 00:14:15.100
paper at the time, proposing
the transformer architecture.

00:14:15.100 --> 00:14:17.965
But what we found since then
is that you can just basically

00:14:17.965 --> 00:14:21.710
copy paste this architecture
and use it everywhere.

00:14:21.710 --> 00:14:23.889
And what's changing is
the details of the data,

00:14:23.889 --> 00:14:26.500
and the chunking of the
data, and how you feed it in.

00:14:26.500 --> 00:14:28.085
And that's a
caricature, but it's

00:14:28.085 --> 00:14:29.960
kind of like a correct
first order statement.

00:14:29.960 --> 00:14:32.800
And so now, papers are
even more similar looking

00:14:32.799 --> 00:14:34.849
because everyone's
just using transformer.

00:14:34.850 --> 00:14:38.769
And so this convergence
was remarkable to watch

00:14:38.769 --> 00:14:40.210
and unfolded over
the last decade.

00:14:40.210 --> 00:14:42.340
And it's pretty crazy to me.

00:14:42.340 --> 00:14:44.038
What I find
interesting is I think

00:14:44.038 --> 00:14:46.330
this is some kind of a hint
that we're maybe converging

00:14:46.330 --> 00:14:48.080
to something that maybe
the brain is doing

00:14:48.080 --> 00:14:50.560
because the brain is very
homogeneous and uniform

00:14:50.559 --> 00:14:52.831
across the entire
sheet of your cortex.

00:14:52.831 --> 00:14:54.789
And OK, maybe some of
the details are changing,

00:14:54.789 --> 00:14:56.409
but those feel like
hyperparameters

00:14:56.409 --> 00:14:57.490
like in a transformer.

00:14:57.490 --> 00:14:59.560
But your auditory cortex
and your visual cortex

00:14:59.559 --> 00:15:01.029
and everything else
looks very similar.

00:15:01.029 --> 00:15:02.779
And so maybe we're
converging to some kind

00:15:02.779 --> 00:15:06.100
of a uniform powerful
learning algorithm here.

00:15:06.100 --> 00:15:09.060
Something like that, I think,
is interesting and exciting.

00:15:09.059 --> 00:15:11.309
OK, so I want to talk about
where the transformer came

00:15:11.309 --> 00:15:12.771
from briefly, historically.

00:15:12.772 --> 00:15:15.430
So I want to start in 2003.

00:15:15.429 --> 00:15:17.084
I like this paper quite a bit.

00:15:17.085 --> 00:15:21.190
It was the first popular
application of neural networks

00:15:21.190 --> 00:15:22.690
to the problem of
language modeling,

00:15:22.690 --> 00:15:24.398
so predicting in this
case, the next word

00:15:24.398 --> 00:15:26.148
in the sequence, which
allows you to build

00:15:26.148 --> 00:15:27.320
generative models over text.

00:15:27.320 --> 00:15:29.695
And in this case, they were
using a multilayer perceptron,

00:15:29.695 --> 00:15:30.860
so very simple neural net.

00:15:30.860 --> 00:15:33.442
The neural nets took three words
and predicted the probability

00:15:33.442 --> 00:15:36.000
distribution for the
fourth word in a sequence.

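NOTE
A minimal sketch of the 2003-style neural language model described here: embed the previous three words, concatenate the embeddings, pass them through one tanh hidden layer, and apply a softmax over the vocabulary for the fourth word. The vocabulary size, dimensions, and random (untrained) weights are illustrative assumptions, and the paper's optional direct input-to-output connections are omitted.
```python
import numpy as np
rng = np.random.default_rng(0)
V, context, emb_dim, hidden = 10, 3, 8, 32  # toy sizes, chosen arbitrarily
C = rng.normal(scale=0.1, size=(V, emb_dim))  # shared word-embedding table
W1 = rng.normal(scale=0.1, size=(context * emb_dim, hidden))
b1 = np.zeros(hidden)
W2 = rng.normal(scale=0.1, size=(hidden, V))
b2 = np.zeros(V)
def next_word_distribution(word_ids):
    """Return P(w_4 | w_1, w_2, w_3) as a length-V probability vector."""
    x = C[word_ids].reshape(-1)        # concatenate the three embeddings
    h = np.tanh(x @ W1 + b1)           # single hidden layer
    logits = h @ W2 + b2
    p = np.exp(logits - logits.max())  # numerically stable softmax
    return p / p.sum()
probs = next_word_distribution([4, 1, 7])
print(probs.shape, probs.sum())  # a distribution over all V words
```
Training would adjust C, W1, and W2 by gradient descent on the cross-entropy of the true fourth word; only the forward pass is shown.
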
00:15:36.000 --> 00:15:39.519
So this was well and
good at this point.

00:15:39.519 --> 00:15:41.710
Now, over time, people
started to apply this

00:15:41.710 --> 00:15:43.610
to machine translation.

00:15:43.610 --> 00:15:45.759
So that brings us to
the sequence to sequence paper

00:15:45.759 --> 00:15:48.009
from 2014 that was
pretty influential,

00:15:48.009 --> 00:15:49.812
and the big problem
here was OK, we

00:15:49.812 --> 00:15:52.269
don't just want to take three
words and predict the fourth.

00:15:52.269 --> 00:15:55.329
We want to predict how to
go from an English sentence

00:15:55.330 --> 00:15:56.830
to a French sentence.

00:15:56.830 --> 00:15:58.387
And the key problem
was OK, you can

00:15:58.386 --> 00:16:00.969
have an arbitrary number of words
in English and an arbitrary number

00:16:00.970 --> 00:16:03.040
of words in French,
so how do you

00:16:03.039 --> 00:16:04.750
get an architecture
that can process

00:16:04.750 --> 00:16:06.820
this variably sized input?

00:16:06.820 --> 00:16:10.330
And so here they used an
LSTM, and there's basically

00:16:10.330 --> 00:16:16.160
two chunks of this, which are
covered by the slack, by this.

00:16:16.159 --> 00:16:19.009
But basically, you have an
encoder LSTM on the left,

00:16:19.009 --> 00:16:22.189
and it just consumes
one word at a time

00:16:22.190 --> 00:16:24.230
and builds up a context
of what it has read.

00:16:24.230 --> 00:16:26.899
And then that acts as
a conditioning vector

00:16:26.899 --> 00:16:29.019
to the decoder RNN or LSTM.

00:16:29.019 --> 00:16:30.394
That basically
goes chonk, chonk,

00:16:30.394 --> 00:16:32.299
chonk for the next
word in a sequence,

00:16:32.299 --> 00:16:35.437
translating the English to
French or something like that.

00:16:35.437 --> 00:16:37.730
Now, the big problem with
this, that people identified,

00:16:37.730 --> 00:16:40.129
I think, very quickly
and tried to resolve

00:16:40.129 --> 00:16:43.320
is that there's what's called
this encoder bottleneck.

00:16:43.320 --> 00:16:46.400
So this entire English sentence
that we are trying to condition

00:16:46.399 --> 00:16:48.289
on is packed into
a single vector

00:16:48.289 --> 00:16:50.879
that goes from the
encoder to the decoder.

00:16:50.879 --> 00:16:52.547
And so this is just
too much information

00:16:52.547 --> 00:16:54.338
to potentially maintain
in a single vector,

00:16:54.337 --> 00:16:55.579
and that didn't seem correct.

00:16:55.580 --> 00:16:57.455
And so people were
looking around for ways

00:16:57.455 --> 00:17:00.800
to alleviate
the encoder bottleneck, as it

00:17:00.799 --> 00:17:02.079
was called at the time.

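NOTE
To make the bottleneck concrete, here is a hedged sketch in which a plain tanh RNN stands in for the LSTM encoder (a simplification): whether the source sentence has 3 words or 50, everything the decoder receives is one hidden vector of fixed size. The dimensions and random weights are assumptions.
```python
import numpy as np
rng = np.random.default_rng(0)
emb_dim, hid_dim = 8, 16
W_x = rng.normal(scale=0.1, size=(emb_dim, hid_dim))
W_h = rng.normal(scale=0.1, size=(hid_dim, hid_dim))
def encode(source_embeddings):
    """Run a simple RNN over the source and return only its final hidden state."""
    h = np.zeros(hid_dim)
    for x_t in source_embeddings:  # consume one word at a time
        h = np.tanh(x_t @ W_x + h @ W_h)
    return h  # the single vector handed to the decoder
short_vec = encode(rng.normal(size=(3, emb_dim)))
long_vec = encode(rng.normal(size=(50, emb_dim)))
print(short_vec.shape, long_vec.shape)  # both (16,): 50 words squeezed into the same 16 numbers
```
An LSTM adds gates that help it retain information longer, but the interface to the decoder is still this one fixed-size vector.
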
00:17:02.080 --> 00:17:03.773
And so that brings
us to this paper,

00:17:03.773 --> 00:17:05.690
Neural Machine Translation
by Jointly Learning

00:17:05.690 --> 00:17:07.549
to Align and Translate.

00:17:07.549 --> 00:17:11.552
And here, just quoting from
the abstract, "in this paper,

00:17:11.553 --> 00:17:13.720
we conjectured that the use
of a fixed length vector

00:17:13.720 --> 00:17:15.553
is a bottleneck in
improving the performance

00:17:15.553 --> 00:17:17.304
of the basic
encoder-decoder architecture

00:17:17.304 --> 00:17:19.720
and propose to extend
this by allowing

00:17:19.720 --> 00:17:21.700
the model to
automatically soft search

00:17:21.700 --> 00:17:24.366
for parts of the source sentence
that are relevant to predicting

00:17:24.366 --> 00:17:28.029
a target word without
having to form

00:17:28.029 --> 00:17:30.049
these parts as a hard
segment explicitly."

00:17:30.049 --> 00:17:34.690
So this was a way to look
back to the words that

00:17:34.690 --> 00:17:35.950
are coming from the encoder.

00:17:35.950 --> 00:17:38.390
And it was achieved
using this soft search.

00:17:38.390 --> 00:17:42.250
So as you are
decoding in the words

00:17:42.250 --> 00:17:44.172
here, while you
are decoding them,

00:17:44.172 --> 00:17:45.880
you are allowed to
look back at the words

00:17:45.880 --> 00:17:49.150
at the encoder via this soft
attention mechanism proposed

00:17:49.150 --> 00:17:50.180
in this paper.

00:17:50.180 --> 00:17:52.735
And so this paper, I think,
is the first time that I saw,

00:17:52.734 --> 00:17:55.689
basically, attention.

00:17:55.690 --> 00:17:58.990
So your context vector
that comes from the encoder

00:17:58.990 --> 00:18:01.150
is a weighted sum
of the hidden states

00:18:01.150 --> 00:18:05.470
of the words in the encoding.

00:18:05.470 --> 00:18:07.450
And then the weights
of this sum come

00:18:07.450 --> 00:18:10.900
from a softmax that is based
on these compatibilities

00:18:10.900 --> 00:18:13.300
between the current
state as you're decoding

00:18:13.299 --> 00:18:15.325
and the hidden states
generated by the encoder.

00:18:15.325 --> 00:18:17.200
And so this is the first
time that really you

00:18:17.200 --> 00:18:22.059
start to look at it, and these
are the current modern equations

00:18:22.059 --> 00:18:23.259
of the attention.
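
NOTE
A minimal sketch of the context vector described above, in plain Python (not code from the lecture). It scores each encoder hidden state with a plain dot product for brevity; the paper quoted above uses a small learned scoring network instead. The helper names softmax and context_vector are illustrative.
```python
import math
def softmax(scores):
    # subtract the max before exponentiating, for numerical stability
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]
def context_vector(decoder_state, encoder_states):
    # compatibility of the current decoder state with each encoder hidden state
    # (plain dot product here; the quoted paper uses a small learned scoring network)
    scores = [sum(d * h for d, h in zip(decoder_state, hs)) for hs in encoder_states]
    weights = softmax(scores)
    # context vector = weighted sum of the encoder hidden states
    dim = len(encoder_states[0])
    context = [sum(w * hs[i] for w, hs in zip(weights, encoder_states)) for i in range(dim)]
    return context, weights
encoder_states = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
context, weights = context_vector([10.0, 0.0], encoder_states)
```
With the decoder state [10.0, 0.0], almost all of the softmax weight lands on the first encoder state, so the context vector ends up close to that state.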

00:18:23.259 --> 00:18:25.509
And I think this was the
first paper that I saw it in.

00:18:25.509 --> 00:18:27.670
It's the first time
that the word

00:18:27.670 --> 00:18:32.029
"attention" is used, as far as I
know, for this mechanism.

00:18:32.029 --> 00:18:34.480
So I actually tried to dig
into the details of the history

00:18:34.480 --> 00:18:35.740
of attention.

00:18:35.740 --> 00:18:38.518
So the first author
here, Dzmitry, I

00:18:38.518 --> 00:18:40.059
had an email
correspondence with him,

00:18:40.059 --> 00:18:41.440
and I basically
sent him an email.

00:18:41.440 --> 00:18:43.000
I'm like, Dzmitry, this
is really interesting.

00:18:43.000 --> 00:18:44.259
Transformers have taken over.

00:18:44.259 --> 00:18:45.819
Where did you come up
with the soft attention

00:18:45.819 --> 00:18:48.309
mechanism that ends up being
the heart of the transformer?

00:18:48.309 --> 00:18:52.037
And to my surprise, he wrote me
back this massive email, which

00:18:52.037 --> 00:18:52.995
was really fascinating.

00:18:52.994 --> 00:18:54.577
So this is an excerpt
from that email.

00:18:57.119 --> 00:18:59.969
So basically, he talks about
how he was looking for a way

00:18:59.970 --> 00:19:02.490
to avoid this bottleneck
between the encoder and decoder.

00:19:02.490 --> 00:19:04.049
He had some ideas
about cursors that

00:19:04.049 --> 00:19:06.809
traverse the sequences
that didn't quite work out.

00:19:06.809 --> 00:19:08.909
And then here, "so one
day, I had this thought

00:19:08.910 --> 00:19:10.701
that it would be nice
to enable the decoder

00:19:10.701 --> 00:19:13.680
RNN to learn to search where
to put the cursor in the source

00:19:13.680 --> 00:19:14.610
sequence.

00:19:14.609 --> 00:19:16.692
This was sort of inspired
by translation exercises

00:19:16.692 --> 00:19:21.150
that learning English in
my middle school involved.

00:19:21.150 --> 00:19:23.567
Your gaze shifts back and forth
between source and target

00:19:23.567 --> 00:19:24.692
sequence as you translate."

00:19:24.692 --> 00:19:27.150
So literally, I thought that
this was kind of interesting,

00:19:27.150 --> 00:19:28.425
that he's not a native
English speaker,

00:19:28.424 --> 00:19:31.079
and here, that gave him an edge
in this machine translation

00:19:31.079 --> 00:19:34.519
that led to attention and
then led to the transformer.

00:19:34.519 --> 00:19:37.019
So that's really fascinating.

00:19:37.019 --> 00:19:38.670
"I expressed a soft
search as softmax

00:19:38.670 --> 00:19:40.920
and then weighted averaging
of the [INAUDIBLE] states.

00:19:40.920 --> 00:19:43.800
And basically, to
my great excitement,

00:19:43.799 --> 00:19:45.750
this worked from
the very first try."

00:19:45.750 --> 00:19:48.390
So really, I think, an
interesting piece of history.

00:19:48.390 --> 00:19:51.030
And as it later turned out,
the name RNN search

00:19:51.029 --> 00:19:54.059
was kind of lame, so the
better name attention came

00:19:54.059 --> 00:19:57.179
from Yoshua on one
of the final passes

00:19:57.180 --> 00:19:58.660
as they went over the paper.

00:19:58.660 --> 00:20:00.960
So maybe Attention
is All You Need

00:20:00.960 --> 00:20:03.682
would have been called RNN
Search is All You Need,

00:20:03.682 --> 00:20:05.099
but we have Yoshua
Bengio to thank

00:20:05.099 --> 00:20:07.049
for a slightly
better name, I would say.

00:20:07.049 --> 00:20:08.940
So apparently,
that's the history

00:20:08.940 --> 00:20:11.620
of this, which I
thought was interesting.

00:20:11.619 --> 00:20:13.709
OK, so that brings us to
2017, which is Attention

00:20:13.710 --> 00:20:14.890
is All You Need.

00:20:14.890 --> 00:20:16.515
So this attention
component, which

00:20:16.515 --> 00:20:19.020
in Dzmitry's paper was
just one small segment,

00:20:19.019 --> 00:20:21.180
and there's all this
bidirectional RNN, RNN

00:20:21.180 --> 00:20:25.235
and decoder, and this Attention
is All You Need paper is saying,

00:20:25.234 --> 00:20:26.944
OK, you can actually
delete everything.

00:20:26.944 --> 00:20:28.319
What's making this
work very well

00:20:28.319 --> 00:20:29.759
is just attention by itself.

00:20:29.759 --> 00:20:32.099
And so delete everything,
keep attention.

00:20:32.099 --> 00:20:35.129
And then what's remarkable about
this paper actually is usually,

00:20:35.130 --> 00:20:36.880
you see papers that
are very incremental.

00:20:36.880 --> 00:20:39.810
They add one thing, and
they show that it's better.

00:20:39.809 --> 00:20:41.309
But I feel like
Attention is All You

00:20:41.309 --> 00:20:44.099
Need was like a mix of multiple
things at the same time.

00:20:44.099 --> 00:20:46.379
They were combined
in a very unique way,

00:20:46.380 --> 00:20:49.110
and they also achieved a
very good local minimum

00:20:49.109 --> 00:20:50.554
in the architecture space.

00:20:50.555 --> 00:20:52.529
And so to me, this is
really a landmark paper

00:20:52.529 --> 00:20:55.859
that is quite
remarkable and, I think,

00:20:55.859 --> 00:20:58.649
had quite a lot of
work behind the scenes.

00:20:58.650 --> 00:21:01.380
So delete all the RNN,
just keep attention.

00:21:01.380 --> 00:21:03.562
Because attention
operates over sets--

00:21:03.561 --> 00:21:05.269
and I'm going to go
into this in a second--

00:21:05.269 --> 00:21:07.228
you now need to positionally
encode your inputs

00:21:07.228 --> 00:21:10.240
because attention doesn't have
the notion of space by itself.

00:21:14.684 --> 00:21:17.669
I have to be very careful.

00:21:17.670 --> 00:21:19.904
They adopted this
residual network structure

00:21:19.904 --> 00:21:21.450
from ResNets.

00:21:21.450 --> 00:21:24.470
They interspersed attention
with multi-layer perceptrons.

00:21:24.470 --> 00:21:27.012
They used layer norms, which
came from a different paper.

00:21:27.012 --> 00:21:29.429
They introduced the concept
of multiple heads of attention

00:21:29.430 --> 00:21:30.870
that were applied in parallel.

00:21:30.869 --> 00:21:33.000
And they gave us, I think,
like a fairly good set

00:21:33.000 --> 00:21:35.279
of hyperparameters that
to this day are used.

00:21:35.279 --> 00:21:39.509
So the expansion factor in the
multi-layer perceptron goes up

00:21:39.509 --> 00:21:40.397
by 4X--

00:21:40.397 --> 00:21:41.939
and we'll go into
a bit more detail--

00:21:41.940 --> 00:21:43.230
and this 4X has stuck around.
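
NOTE
A rough sketch of the 4X expansion in the MLP, in plain Python with tiny hypothetical sizes. Biases are omitted, and the helper names (gelu, make_linear, apply_linear, mlp) are made up for illustration. The nonlinearity is the tanh approximation of GELU, which GPT-2 uses; the 2017 paper used ReLU, with an inner width of 2048 for a model width of 512, which is the same 4X.
```python
import math
import random
def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))
def make_linear(n_in, n_out, rng):
    # hypothetical helper: random weights stored as n_out rows of n_in entries
    return [[rng.gauss(0.0, 0.02) for _ in range(n_in)] for _ in range(n_out)]
def apply_linear(w, x):
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]
def mlp(x, w_up, w_down):
    # expand to 4 * n_embd, apply the nonlinearity, then project back to n_embd
    h = [gelu(v) for v in apply_linear(w_up, x)]
    return apply_linear(w_down, h)
n_embd = 8
rng = random.Random(0)
w_up = make_linear(n_embd, 4 * n_embd, rng)
w_down = make_linear(4 * n_embd, n_embd, rng)
out = mlp([1.0] * n_embd, w_up, w_down)
```
The hidden layer is four times wider than the input, but the output has the same width as the input, so the MLP can sit on the residual stream.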

00:21:43.230 --> 00:21:44.968
And I believe there's
a number of papers

00:21:44.968 --> 00:21:47.009
that try to play with all
kinds of little details

00:21:47.009 --> 00:21:50.730
of the transformer, and nothing
sticks because this is actually

00:21:50.730 --> 00:21:51.450
quite good.

00:21:51.450 --> 00:21:54.930
The only thing to my
knowledge that did stick

00:21:54.930 --> 00:21:56.820
was this reshuffling
of the layer norms

00:21:56.819 --> 00:21:59.419
to go into the prenorm
version where here you

00:21:59.420 --> 00:22:01.920
see the layer norms are after
the multi-headed attention and feed

00:22:01.920 --> 00:22:02.759
forward.

00:22:02.759 --> 00:22:04.277
They just put them
before instead.

00:22:04.277 --> 00:22:06.360
So just reshuffling of
layer norms, but otherwise,

00:22:06.359 --> 00:22:08.567
the GPTs and everything else
that you're seeing today

00:22:08.567 --> 00:22:11.930
is basically the 2017
architecture from 5 years ago.
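
NOTE
A sketch contrasting the two layer norm placements, assuming attn and mlp are any functions that map a vector to a vector of the same length. The layer norm here has no learned scale or shift, which is a simplification, and the function names are illustrative.
```python
import math
def layer_norm(x, eps=1e-5):
    # normalize one vector to zero mean and unit variance (no learned scale/shift)
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]
def add(a, b):
    return [u + v for u, v in zip(a, b)]
def post_norm_block(x, attn, mlp):
    # 2017 ordering: sublayer, add the residual, then normalize
    x = layer_norm(add(x, attn(x)))
    return layer_norm(add(x, mlp(x)))
def pre_norm_block(x, attn, mlp):
    # pre-norm ordering (GPT-2 style): normalize only the input to each sublayer
    x = add(x, attn(layer_norm(x)))
    return add(x, mlp(layer_norm(x)))
```
With sublayers that output zeros, the pre-norm block returns its input unchanged, while the post-norm block returns a normalized vector: in pre-norm the residual stream itself is never normalized inside the block.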

00:22:11.930 --> 00:22:13.680
And even though everyone
is working on it,

00:22:13.680 --> 00:22:15.765
it has proven
remarkably resilient,

00:22:15.765 --> 00:22:17.280
which I think is
really interesting.

00:22:17.279 --> 00:22:18.779
There are innovations
that, I think,

00:22:18.779 --> 00:22:21.539
have been adopted also
in positional encoding.

00:22:21.539 --> 00:22:24.000
It's more common to use
different rotary and relative

00:22:24.000 --> 00:22:25.843
positional encoding and so on.

00:22:25.843 --> 00:22:28.259
So I think there have been
changes, but for the most part,

00:22:28.259 --> 00:22:31.069
it's proven very resilient.

00:22:31.069 --> 00:22:32.799
So really quite an
interesting paper.

00:22:32.799 --> 00:22:36.720
Now, I wanted to go into
the attention mechanism.

00:22:36.720 --> 00:22:43.092
And I think, the way I interpret
it is not similar to the ways

00:22:43.092 --> 00:22:44.550
that I've seen it
presented before.

00:22:44.549 --> 00:22:47.417
So let me try a different
way of how I see it.

00:22:47.417 --> 00:22:49.959
Basically, to me, attention is
kind of like the communication

00:22:49.960 --> 00:22:52.210
phase of the transformer,
and the transformer

00:22:52.210 --> 00:22:55.616
interweaves two phases: the
communication phase, which

00:22:55.616 --> 00:22:57.700
is the multi-headed
attention, and the computation

00:22:57.700 --> 00:23:00.069
stage, which is this
multi-layer perceptron

00:23:00.069 --> 00:23:01.539
or [INAUDIBLE].

00:23:01.539 --> 00:23:03.670
So in the communication
phase, it's

00:23:03.670 --> 00:23:05.170
really just a data
dependent message

00:23:05.170 --> 00:23:07.279
passing on directed graphs.

00:23:07.279 --> 00:23:09.279
And you can think of it
as OK, forget everything

00:23:09.279 --> 00:23:10.960
with machine
translation, everything.

00:23:10.960 --> 00:23:13.120
Let's just-- we have
directed graphs.

00:23:13.119 --> 00:23:16.000
At each node, you
are storing a vector.

00:23:16.000 --> 00:23:18.714
And then let me talk now
about the communication

00:23:18.714 --> 00:23:20.589
phase of how these
vectors talk to each other

00:23:20.589 --> 00:23:21.309
in this directed graph.

00:23:21.309 --> 00:23:23.230
And then the compute
phase later is just

00:23:23.230 --> 00:23:27.700
a multi-layer perceptron, which then
basically acts on every node

00:23:27.700 --> 00:23:28.932
individually.

00:23:28.932 --> 00:23:30.640
But how do these nodes
talk to each other

00:23:30.640 --> 00:23:32.930
in this directed graph?

00:23:32.930 --> 00:23:36.759
So I wrote like
some simple Python--

00:23:36.759 --> 00:23:39.339
I wrote this in Python
basically to create

00:23:39.339 --> 00:23:44.049
one round of communication
of using attention

00:23:44.049 --> 00:23:46.549
as the message passing scheme.

00:23:46.549 --> 00:23:51.204
So here, a node has this
private data vector,

00:23:51.204 --> 00:23:53.079
you can think of it
as private information

00:23:53.079 --> 00:23:54.069
to this node.

00:23:54.069 --> 00:23:57.309
And then it can also emit a
key, a query, and a value.

00:23:57.309 --> 00:24:00.399
And simply, that's done
by linear transformation

00:24:00.400 --> 00:24:01.310
from this node.

00:24:01.309 --> 00:24:07.220
So the key is what are
the things that I am--

00:24:07.220 --> 00:24:07.720
sorry.

00:24:07.720 --> 00:24:10.214
The query is what are the
things that I'm looking for?

00:24:10.214 --> 00:24:12.089
The key is what are
the things that I have?

00:24:12.089 --> 00:24:15.049
And the value is what are the
things that I will communicate?

00:24:15.049 --> 00:24:16.849
And so then when you
have your graph that's

00:24:16.849 --> 00:24:19.254
made up of nodes and some
random edges, when you actually

00:24:19.255 --> 00:24:21.380
have these nodes communicating,
what's happening is

00:24:21.380 --> 00:24:23.536
you loop over all the
nodes individually

00:24:23.536 --> 00:24:27.110
in some random order,
and you're at some node,

00:24:27.109 --> 00:24:29.240
and you get the
query vector q, which

00:24:29.240 --> 00:24:32.595
is, I'm a node in
some graph, and this

00:24:32.595 --> 00:24:33.595
is what I'm looking for.

00:24:33.595 --> 00:24:36.011
And so that's just achieved
via this linear transformation

00:24:36.011 --> 00:24:36.859
here.

00:24:36.859 --> 00:24:39.716
And then we look at all the
inputs that point to this node,

00:24:39.717 --> 00:24:42.050
and then they broadcast what
are the things that I have,

00:24:42.049 --> 00:24:44.029
which is their keys.

00:24:44.029 --> 00:24:45.680
So they broadcast the keys.

00:24:45.680 --> 00:24:49.279
I have the query, then those
interact by dot product

00:24:49.279 --> 00:24:51.210
to get scores.

00:24:51.210 --> 00:24:53.120
So basically, simply
by doing dot product,

00:24:53.119 --> 00:24:55.669
you get some
unnormalized weighting

00:24:55.670 --> 00:24:59.870
of the interestingness of all
of the information in the nodes

00:24:59.869 --> 00:25:02.002
that point to me and to
the things I'm looking for.

00:25:02.002 --> 00:25:03.919
And then when you normalize
that with softmax,

00:25:03.920 --> 00:25:06.743
so it just sums to
1, you basically just

00:25:06.742 --> 00:25:09.409
end up using those scores, which
now sum to 1 and form a probability

00:25:09.410 --> 00:25:13.279
distribution, and you do a
weighted sum of the values

00:25:13.279 --> 00:25:15.079
to get your update.

00:25:15.079 --> 00:25:17.329
So I have a query.

00:25:17.329 --> 00:25:21.500
They have keys, dot products
to get interestingness or like

00:25:21.500 --> 00:25:24.170
affinity, softmax to
normalize it, and then

00:25:24.170 --> 00:25:27.398
weighted sum of those values
flow to me and update me.

00:25:27.397 --> 00:25:29.439
And this is happening for
each node individually.

00:25:29.440 --> 00:25:30.707
And then we update at the end.
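
NOTE
A minimal sketch of one round of this message passing, in plain Python. It is not the code from the lecture: the query, key, and value maps are random matrices shared by all nodes (as in a transformer layer), and adding the weighted sum of values to the node's own vector is an assumption about how "update me" is implemented.
```python
import math
import random
def matvec(w, x):
    # apply a linear map stored as a list of rows
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]
def dot(a, b):
    return sum(u * v for u, v in zip(a, b))
def communicate(data, edges, wq, wk, wv):
    # data[i] is node i's private vector; edges[i] lists the nodes pointing to node i
    updates = []
    for i, x in enumerate(data):
        q = matvec(wq, x)                                 # what am I looking for?
        keys = [matvec(wk, data[j]) for j in edges[i]]    # what do my inputs have?
        values = [matvec(wv, data[j]) for j in edges[i]]  # what will they communicate?
        scores = [dot(q, k) for k in keys]                # unnormalized affinities
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        probs = [e / total for e in exps]                 # softmax: sums to 1
        updates.append([sum(p * v[d] for p, v in zip(probs, values)) for d in range(len(x))])
    # every node updates only at the end, so all of them read the same old state
    return [[a + b for a, b in zip(x, u)] for x, u in zip(data, updates)]
rng = random.Random(0)
dim = 4
def rand_matrix():
    return [[rng.gauss(0.0, 0.5) for _ in range(dim)] for _ in range(dim)]
wq, wk, wv = rand_matrix(), rand_matrix(), rand_matrix()
data = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(3)]
edges = {0: [0], 1: [0, 1], 2: [0, 1, 2]}
new_data = communicate(data, edges, wq, wk, wv)
```
Node 0 is pointed to only by itself, so all of its softmax weight goes to its own value.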

00:25:30.707 --> 00:25:32.540
And so this kind of a
message passing scheme

00:25:32.539 --> 00:25:35.990
is at the heart of
the transformer.

00:25:35.990 --> 00:25:40.204
And it happens in the more
vectorized batched way

00:25:40.204 --> 00:25:44.210
that is more confusing and is
also interspersed with layer

00:25:44.210 --> 00:25:46.640
norms and things like that
to make the training behave

00:25:46.640 --> 00:25:47.473
better.

00:25:47.472 --> 00:25:49.639
But that's roughly what's
happening in the attention

00:25:49.640 --> 00:25:51.140
mechanism, I think,
on a high level.

00:25:53.720 --> 00:25:59.029
So yeah, so in the communication
phase of the transformer, then

00:25:59.029 --> 00:26:00.785
this message passing
scheme happens

00:26:00.785 --> 00:26:06.490
in every head in parallel and
then in every layer in series

00:26:06.490 --> 00:26:08.529
and with different
weights each time.

00:26:08.529 --> 00:26:13.149
And that's it as far as the
multi-headed attention goes.

00:26:13.150 --> 00:26:15.790
And so if you look at these
encoder-decoder models,

00:26:15.789 --> 00:26:18.042
you can think of it then in
terms of the connectivity

00:26:18.042 --> 00:26:19.209
of these nodes in the graph.

00:26:19.210 --> 00:26:21.920
You can think of it as like,
OK, all these tokens that

00:26:21.920 --> 00:26:23.920
are in the encoder that
we want to condition on,

00:26:23.920 --> 00:26:25.600
they are fully
connected to each other.

00:26:25.599 --> 00:26:28.329
So when they communicate,
they communicate fully

00:26:28.329 --> 00:26:30.589
when you calculate
their features.

00:26:30.589 --> 00:26:32.139
But in the decoder,
because we are

00:26:32.140 --> 00:26:33.627
trying to have a
language model, we

00:26:33.626 --> 00:26:35.710
don't want to have
communication from future tokens

00:26:35.710 --> 00:26:38.170
because they give away
the answer at this step.

00:26:38.170 --> 00:26:40.810
So the tokens in the
decoder are fully connected

00:26:40.809 --> 00:26:43.644
from all the encoder
states, and then they

00:26:43.644 --> 00:26:46.575
are also fully connected from
everything that was decoded before them.

00:26:46.575 --> 00:26:49.150
And so you end up with
this triangular structure

00:26:49.150 --> 00:26:50.560
in the data graph.

00:26:50.559 --> 00:26:52.359
But that's the
message passing scheme

00:26:52.359 --> 00:26:54.814
that this basically implements.
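
NOTE
A small sketch of the three connectivity patterns, written as hypothetical adjacency lists (edges[i] lists the nodes that node i reads from) and turned into 0/1 masks. The function names are illustrative.
```python
def encoder_self_edges(n_src):
    # every encoder token can read from every encoder token
    return {i: list(range(n_src)) for i in range(n_src)}
def decoder_self_edges(n_tgt):
    # decoder token i reads only from decoder tokens 0..i (no peeking at the future)
    return {i: list(range(i + 1)) for i in range(n_tgt)}
def cross_edges(n_tgt, n_src):
    # every decoder token can read from every top-layer encoder token
    return {i: list(range(n_src)) for i in range(n_tgt)}
def as_mask(edges, n_rows, n_cols):
    return [[1 if j in edges[i] else 0 for j in range(n_cols)] for i in range(n_rows)]
for row in as_mask(decoder_self_edges(4), 4, 4):
    print(row)
# [1, 0, 0, 0]
# [1, 1, 0, 0]
# [1, 1, 1, 0]
# [1, 1, 1, 1]
```
The decoder's self-attention mask is the lower-triangular structure; the encoder and cross-attention masks are all ones.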

00:26:54.815 --> 00:26:57.190
And then you have to be also
a little bit careful because

00:26:57.190 --> 00:26:59.065
in the cross attention
here with the decoder,

00:26:59.065 --> 00:27:01.620
you consume the features
from the top of the encoder.

00:27:01.619 --> 00:27:03.952
So think of it as
in the encoder,

00:27:03.952 --> 00:27:05.619
all the nodes are
looking at each other,

00:27:05.619 --> 00:27:08.319
all the tokens are looking at
each other many, many times.

00:27:08.319 --> 00:27:09.759
And they really figure
out what's in there,

00:27:09.759 --> 00:27:12.301
and then the decoder is
looking only at the top nodes.

00:27:14.875 --> 00:27:16.750
So that's roughly the
message passing scheme.

00:27:16.750 --> 00:27:18.750
I was going to go into
more of an implementation

00:27:18.750 --> 00:27:19.660
of a transformer.

00:27:19.660 --> 00:27:23.125
I don't know if there's
any questions about this.

00:27:23.125 --> 00:27:26.434
[INAUDIBLE] self-attention
and multi-headed attention,

00:27:26.434 --> 00:27:30.419
but what is the
advantage of [INAUDIBLE]?

00:27:30.420 --> 00:27:35.370
Yeah, so self-attention and
multi-headed attention, so

00:27:35.369 --> 00:27:38.000
the multi-headed attention is
just this attention scheme,

00:27:38.000 --> 00:27:40.717
but it's just applied
multiple times in parallel.

00:27:40.717 --> 00:27:42.800
Multiple heads just means
independent applications

00:27:42.799 --> 00:27:44.970
of the same attention.

00:27:44.970 --> 00:27:47.990
So this message passing
scheme basically just

00:27:47.990 --> 00:27:49.970
happens in parallel
multiple times

00:27:49.970 --> 00:27:52.940
with different weights for
the query, key, and value.

00:27:52.940 --> 00:27:55.130
So you can almost look at
it like in parallel, I'm

00:27:55.130 --> 00:27:57.422
looking for, I'm seeking
different kinds of information

00:27:57.422 --> 00:27:59.029
from different nodes.

00:27:59.029 --> 00:28:01.024
And I'm collecting it
all in the same node.

00:28:01.025 --> 00:28:03.390
It's all done in parallel.

00:28:03.390 --> 00:28:06.980
So heads is really just
copy-paste in parallel.

00:28:06.980 --> 00:28:12.682
And layers are
copy-paste but in series.
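
NOTE
A sketch of heads in parallel and layers in series, in plain Python with tiny hypothetical sizes. Each head uses causal attention with its own query, key, and value weights, and scores are scaled by one over the square root of the head size, as in the 2017 paper. The output projection that usually follows the concatenation is left out to keep it short, and all names are illustrative.
```python
import math
import random
def matvec(w, x):
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]
def softmax(s):
    m = max(s)
    e = [math.exp(v - m) for v in s]
    t = sum(e)
    return [v / t for v in e]
def attention_head(xs, wq, wk, wv):
    # one head of causal self-attention over a sequence of vectors xs
    qs = [matvec(wq, x) for x in xs]
    ks = [matvec(wk, x) for x in xs]
    vs = [matvec(wv, x) for x in xs]
    out = []
    for i, q in enumerate(qs):
        scale = 1.0 / math.sqrt(len(q))
        p = softmax([scale * sum(a * b for a, b in zip(q, ks[j])) for j in range(i + 1)])
        out.append([sum(p[j] * vs[j][d] for j in range(i + 1)) for d in range(len(q))])
    return out
def multi_head(xs, heads):
    # the same attention applied in parallel, each head with its own weights;
    # the head outputs are concatenated per position
    outs = [attention_head(xs, *h) for h in heads]
    return [sum((o[i] for o in outs), []) for i in range(len(xs))]
rng = random.Random(0)
n_embd, n_head = 4, 2
head_size = n_embd // n_head
def rand_w(rows, cols):
    return [[rng.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]
def make_heads():
    return [(rand_w(head_size, n_embd), rand_w(head_size, n_embd), rand_w(head_size, n_embd)) for _ in range(n_head)]
layers = [make_heads() for _ in range(3)]  # different weights for every layer
xs = [[rng.gauss(0.0, 1.0) for _ in range(n_embd)] for _ in range(5)]
for heads in layers:  # layers are applied in series
    ys = multi_head(xs, heads)
    xs = [[a + b for a, b in zip(x, y)] for x, y in zip(xs, ys)]  # residual add
```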

00:28:12.682 --> 00:28:15.940
Maybe that makes sense.

00:28:15.940 --> 00:28:18.610
And self-attention, when
it's self-attention,

00:28:18.609 --> 00:28:21.699
what it's referring to
is that the node here

00:28:21.700 --> 00:28:23.055
produces each node here.

00:28:23.055 --> 00:28:25.632
So as I described it here,
this is really self-attention

00:28:25.632 --> 00:28:27.340
because every one of
these nodes produces

00:28:27.339 --> 00:28:30.429
a key, a query, and a value
from this individual node.

00:28:30.430 --> 00:28:33.850
When you have cross-attention,
you have one cross-attention

00:28:33.849 --> 00:28:36.929
here, coming from the encoder.

00:28:36.930 --> 00:28:38.680
That just means that
the queries are still

00:28:38.680 --> 00:28:42.400
produced from this node,
but the keys and the values

00:28:42.400 --> 00:28:44.920
are produced as a
function of nodes that

00:28:44.920 --> 00:28:48.130
are coming from the encoder.

00:28:48.130 --> 00:28:52.050
So I have my queries because
I'm trying to decode some--

00:28:52.049 --> 00:28:53.932
the fifth word in the sequence.

00:28:53.932 --> 00:28:55.349
And I'm looking
for certain things

00:28:55.349 --> 00:28:56.759
because I'm the fifth word.

00:28:56.759 --> 00:28:58.769
And then the keys and
the values in terms

00:28:58.769 --> 00:29:01.349
of the source of information
that could answer my queries

00:29:01.349 --> 00:29:04.019
can come from the previous
nodes in the current decoding

00:29:04.019 --> 00:29:06.670
sequence or from the
top of the encoder.

00:29:06.670 --> 00:29:09.240
So all the nodes that
have already seen all

00:29:09.240 --> 00:29:12.120
of the encoding tokens many,
many times can now broadcast

00:29:12.119 --> 00:29:14.319
what they contain in
terms of information.

00:29:14.319 --> 00:29:18.652
So I guess, to summarize,
the self-attention is--

00:29:18.652 --> 00:29:20.360
sorry, cross-attention
and self-attention

00:29:20.359 --> 00:29:24.199
only differ in where the keys
and the values come from.

00:29:24.200 --> 00:29:28.130
Either the keys and values
are produced from this node,

00:29:28.130 --> 00:29:31.340
or they are produced from some
external source like an encoder

00:29:31.339 --> 00:29:33.199
and the nodes over there.

00:29:33.200 --> 00:29:39.000
But algorithmically, it's the
same mathematical operations.
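
NOTE
A sketch of the point that self-attention and cross-attention are the same operation with a different source for the keys and values. Identity projections and the absence of a causal mask are simplifications, the function name attend is illustrative, and the vectors are made up.
```python
import math
def attend(queries_from, keys_values_from):
    # identity projections keep the sketch short; real layers apply learned linear maps first
    out = []
    for q in queries_from:
        scores = [sum(a * b for a, b in zip(q, k)) for k in keys_values_from]
        m = max(scores)
        e = [math.exp(s - m) for s in scores]
        total = sum(e)
        out.append([sum(w / total * kv[d] for w, kv in zip(e, keys_values_from)) for d in range(len(q))])
    return out
decoder_states = [[1.0, 0.0], [0.0, 1.0]]
encoder_top = [[5.0, 0.0], [0.0, 5.0], [1.0, 1.0]]
self_out = attend(decoder_states, decoder_states)  # self-attention: keys/values from the same nodes
cross_out = attend(decoder_states, encoder_top)    # cross-attention: keys/values from the encoder
```
Only the second argument changes between the two calls; the arithmetic inside attend is identical.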

00:29:39.000 --> 00:29:39.961
Question.

00:29:39.961 --> 00:29:40.599
Yeah, OK.

00:29:40.599 --> 00:29:41.899
So two questions for you.

00:29:41.900 --> 00:29:48.690
First question is, in the
message passing [INAUDIBLE]

00:29:56.690 --> 00:30:00.799
So think of-- so each one
of these nodes is a token.

00:30:04.067 --> 00:30:06.109
I guess they don't have
a very good picture of it

00:30:06.109 --> 00:30:06.901
in the transformer.

00:30:06.902 --> 00:30:14.930
But this node here could
represent the third word

00:30:14.930 --> 00:30:19.505
in the output in the decoder,
and in the beginning,

00:30:19.505 --> 00:30:21.290
it is just the
embedding of the word.

00:30:27.119 --> 00:30:30.669
And then, OK, I have to
think through this analogy

00:30:30.670 --> 00:30:31.420
a little bit more.

00:30:31.420 --> 00:30:32.711
I came up with it this morning.

00:30:32.711 --> 00:30:34.400
[LAUGHTER]

00:30:34.400 --> 00:30:35.830
[INAUDIBLE]

00:30:39.940 --> 00:30:45.865
What example of instantiation
[INAUDIBLE] nodes

00:30:45.865 --> 00:30:50.299
as in in blocks were embedding?

00:30:50.299 --> 00:30:53.201
These nodes are
basically the vectors.

00:30:53.201 --> 00:30:54.410
I'll go to an implementation.

00:30:54.410 --> 00:30:56.493
I'll go to the implementation,
and then maybe I'll

00:30:56.492 --> 00:30:58.779
make the connections
to the graph.

00:30:58.779 --> 00:31:01.490
So let me try to first
go to-- let me now go to,

00:31:01.490 --> 00:31:03.259
with this intuition
in mind, at least,

00:31:03.259 --> 00:31:05.259
to nanoGPT, which is a
concrete implementation

00:31:05.259 --> 00:31:06.980
of a transformer
that is very minimal.

00:31:06.980 --> 00:31:08.839
So I worked on this
over the last few days,

00:31:08.839 --> 00:31:11.737
and here it is reproducing
GPT-2 on OpenWebText.

00:31:11.738 --> 00:31:14.029
So it's a pretty serious
implementation that reproduces

00:31:14.029 --> 00:31:17.869
GPT-2, I would say, given
enough compute--

00:31:17.869 --> 00:31:21.211
This was one node of 8 GPUs
for 38 hours or something

00:31:21.211 --> 00:31:22.670
like that, if I
remember correctly.

00:31:22.670 --> 00:31:23.910
And it's very readable.

00:31:23.910 --> 00:31:27.170
It's 300 lines, so everyone
can take a look at it.

00:31:27.170 --> 00:31:30.622
And yeah, let me basically
briefly step through it.

00:31:30.622 --> 00:31:34.077
So let's try to have a
decoder-only transformer.

00:31:34.077 --> 00:31:36.119
So what that means is that
it's a language model.

00:31:36.119 --> 00:31:39.936
It tries to model the
next word in the sequence

00:31:39.936 --> 00:31:41.519
or the next character
in the sequence.

00:31:41.519 --> 00:31:43.079
So the data that
we train on

00:31:43.079 --> 00:31:44.309
is always some kind of text.

00:31:44.309 --> 00:31:45.856
So here's some fake Shakespeare.

00:31:45.856 --> 00:31:47.190
Sorry, this is real Shakespeare.

00:31:47.190 --> 00:31:48.600
We're going to produce
fake Shakespeare.

00:31:48.599 --> 00:31:50.099
So this is called
a Tiny Shakespeare

00:31:50.099 --> 00:31:52.346
dataset, which is one of
my favorite toy datasets.

00:31:52.346 --> 00:31:54.180
You take all of
Shakespeare, concatenate it,

00:31:54.180 --> 00:31:55.650
and it's 1 megabyte
file, and then

00:31:55.650 --> 00:31:56.850
you can train
language models on it

00:31:56.849 --> 00:31:58.439
and get infinite
Shakespeare, if you like,

00:31:58.440 --> 00:31:59.690
which I think is kind of cool.

00:31:59.690 --> 00:32:00.761
So we have a text.

00:32:00.761 --> 00:32:02.220
The first thing we
need to do is we

00:32:02.220 --> 00:32:05.160
need to convert it to
a sequence of integers

00:32:05.160 --> 00:32:09.120
because transformers
natively process--

00:32:09.119 --> 00:32:10.661
you can't plug text
into a transformer.

00:32:10.662 --> 00:32:11.912
You need to somehow encode it.

00:32:11.912 --> 00:32:13.380
So the way that
encoding is done is

00:32:13.380 --> 00:32:15.390
we convert, for example,
in the simplest case,

00:32:15.390 --> 00:32:18.810
every character gets an
integer, and then instead of "hi

00:32:18.809 --> 00:32:21.799
there," we would have
this sequence of integers.

00:32:21.799 --> 00:32:25.490
So then you can encode every
single character as an integer

00:32:25.490 --> 00:32:27.529
and get a massive
sequence of integers.

00:32:27.529 --> 00:32:29.089
You just concatenate
it all into one

00:32:29.089 --> 00:32:31.419
large, long
one-dimensional sequence.

00:32:31.420 --> 00:32:32.750
And then you can train on it.
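
NOTE
A minimal sketch of character-level encoding, building the vocabulary from the distinct characters of the text. The names stoi and itos are shorthand for "string to integer" and "integer to string" chosen for illustration; on Tiny Shakespeare you would build the vocabulary from the whole file.
```python
text = "hi there"
chars = sorted(set(text))  # the vocabulary: every distinct character, in sorted order
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}
def encode(s):
    return [stoi[c] for c in s]
def decode(ids):
    return "".join(itos[i] for i in ids)
ids = encode(text)
print(ids)  # [2, 3, 0, 5, 2, 1, 4, 1]
```
decode inverts encode, so the integers can always be turned back into the original characters.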

00:32:32.750 --> 00:32:34.563
Now, here, we only
have a single document.

00:32:34.563 --> 00:32:36.980
In some cases, if you have
multiple independent documents,

00:32:36.980 --> 00:32:38.914
what people like to do
is create special tokens,

00:32:38.914 --> 00:32:40.414
and they intersperse
those documents

00:32:40.414 --> 00:32:42.500
with those special
end of text tokens

00:32:42.500 --> 00:32:46.160
that they splice in between
to create boundaries.

00:32:46.160 --> 00:32:50.860
But those boundaries actually
don't have any modeling impact.

00:32:50.859 --> 00:32:52.609
It's just that the
transformer is supposed

00:32:52.609 --> 00:32:55.849
to learn via backpropagation
that the end of document

00:32:55.849 --> 00:33:00.019
sequence means that you
should wipe the memory.
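
NOTE
A sketch of splicing an end-of-text token between documents. The byte-level encoder and the id 256 are hypothetical stand-ins, as is the name build_stream; GPT-2's tokenizer, for example, has a dedicated <|endoftext|> token for this purpose.
```python
def build_stream(docs, encode, eot_id):
    # concatenate tokenized documents into one long 1-D sequence,
    # splicing the end-of-text id after each document to mark the boundary
    stream = []
    for doc in docs:
        stream.extend(encode(doc))
        stream.append(eot_id)
    return stream
stream = build_stream(["ab", "c"], lambda s: list(s.encode("utf-8")), eot_id=256)
print(stream)  # [97, 98, 256, 99, 256]
```
Nothing in the model treats 256 specially; as described above, it has to learn from data what the boundary means.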

00:33:00.019 --> 00:33:02.000
OK, so then we produce batches.

00:33:02.000 --> 00:33:04.339
So these batches
of data just mean

00:33:04.339 --> 00:33:06.379
that we go back to the
one-dimensional sequence,

00:33:06.380 --> 00:33:08.780
and we take out chunks
of this sequence.

00:33:08.779 --> 00:33:13.774
So say the block size is 8.
The block size indicates

00:33:13.775 --> 00:33:17.750
the maximum length of context
that your transformer will

00:33:17.750 --> 00:33:18.295
process.

00:33:18.295 --> 00:33:20.509
So if our block size
is 8, that means

00:33:20.509 --> 00:33:23.720
that we are going to have up
to eight characters of context

00:33:23.720 --> 00:33:26.630
to predict the ninth
character in a sequence.

00:33:26.630 --> 00:33:29.120
And the batch size indicates
how many sequences in parallel

00:33:29.119 --> 00:33:30.119
we're going to process.

00:33:30.119 --> 00:33:31.879
And we want this to be
as large as possible,

00:33:31.880 --> 00:33:33.650
so we're fully taking
advantage of the GPU

00:33:33.650 --> 00:33:36.540
and the parallels [INAUDIBLE]
So in this example,

00:33:36.539 --> 00:33:38.000
we're doing a 4 by 8 batch.

00:33:38.000 --> 00:33:41.390
So every row here is an
independent example

00:33:41.390 --> 00:33:47.412
and then every row here is a
small chunk of the sequence

00:33:47.412 --> 00:33:48.620
that we're going to train on.

00:33:48.619 --> 00:33:50.619
And then we have both the
inputs and the targets

00:33:50.619 --> 00:33:52.579
at every single point here.
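
NOTE
A plain-Python sketch of sampling a batch of chunks, with hypothetical names (get_batch, data). The targets are the same windows shifted one position to the right, which is how every input position gets its own target.
```python
import random
def get_batch(data, batch_size, block_size, rng):
    # pick batch_size random starting offsets into the long 1-D token sequence
    starts = [rng.randrange(len(data) - block_size) for _ in range(batch_size)]
    # inputs: block_size tokens; targets: the same window shifted right by one
    x = [data[s:s + block_size] for s in starts]
    y = [data[s + 1:s + 1 + block_size] for s in starts]
    return x, y
data = list(range(100))  # stand-in token ids
x, y = get_batch(data, batch_size=4, block_size=8, rng=random.Random(0))
```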

00:33:52.579 --> 00:33:55.159
So to fully spell out what's
contained in a single 4

00:33:55.160 --> 00:33:57.320
by 8 batch to the transformer--

00:33:57.319 --> 00:33:59.109
I sort of compact it here--

00:33:59.109 --> 00:34:04.669
so when the input is 47, by
itself, the target is 58.

00:34:04.670 --> 00:34:07.279
And when the input is
the sequence 47, 58,

00:34:07.279 --> 00:34:08.929
the target is 1.

00:34:08.929 --> 00:34:13.019
And when it's 47, 58, 1,
the target is 51 and so on.

00:34:13.019 --> 00:34:15.679
So actually, the single batch
of examples that's 4 by 8

00:34:15.679 --> 00:34:17.490
actually has a ton of
individual examples

00:34:17.490 --> 00:34:18.949
that we are expecting
a transformer

00:34:18.949 --> 00:34:21.863
to learn on in parallel.
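
NOTE
A sketch spelling out one row into its individual training examples, using the token ids quoted above (the first three positions of the row). The function name expand_row is illustrative.
```python
def expand_row(x_row, y_row):
    # position t sees the context x_row[:t + 1] and must predict y_row[t]
    return [(x_row[:t + 1], y_row[t]) for t in range(len(x_row))]
for context, target in expand_row([47, 58, 1], [58, 1, 51]):
    print(context, "->", target)
# [47] -> 58
# [47, 58] -> 1
# [47, 58, 1] -> 51
```
A 4 by 8 batch therefore carries 4 * 8 = 32 such (context, target) pairs.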

00:34:21.862 --> 00:34:23.779
And so you'll see that
the batches are learned

00:34:23.780 --> 00:34:28.459
on completely independently, but
the time dimension here, running

00:34:28.458 --> 00:34:30.948
horizontally, is also
trained on in parallel.

00:34:30.949 --> 00:34:34.309
So your real batch size
is more like B times T.

00:34:34.309 --> 00:34:37.340
And it's just that the
context grows linearly

00:34:37.340 --> 00:34:41.329
for the predictions that you
make along the T direction

00:34:41.329 --> 00:34:42.509
in the model.

00:34:42.510 --> 00:34:45.664
So this is all the examples
that the model will learn from,

00:34:45.664 --> 00:34:48.830
this single batch.

00:34:48.829 --> 00:34:52.768
So now, this is the GPT class.

00:34:52.768 --> 00:34:55.946
And because this is
a decoder-only model,

00:34:55.947 --> 00:34:58.280
so we're not going to have
an encoder because there's no

00:34:58.280 --> 00:34:59.952
English we're translating from--

00:34:59.952 --> 00:35:02.119
we're not trying to condition
in some other external

00:35:02.119 --> 00:35:02.779
information.

00:35:02.780 --> 00:35:05.510
We're just trying to produce
a sequence of words that

00:35:05.510 --> 00:35:08.090
are likely to follow each other.

00:35:08.090 --> 00:35:10.658
So this is all PyTorch, and
I'm going slightly faster

00:35:10.657 --> 00:35:12.949
because I'm assuming people
have taken CS231n or something

00:35:12.949 --> 00:35:15.210
along those lines.

00:35:15.210 --> 00:35:19.190
But here in the forward
pass, we take these indices,

00:35:19.190 --> 00:35:24.500
and then we both encode the
identity of the indices,

00:35:24.500 --> 00:35:26.789
just via an embedding
lookup table.

00:35:26.789 --> 00:35:31.190
So every single integer, we
index into a lookup table of

00:35:31.190 --> 00:35:34.460
vectors in this nn.Embedding,
and pull out

00:35:34.460 --> 00:35:38.099
the word vector for that token.

00:35:38.099 --> 00:35:41.431
And then because the
transformer by itself

00:35:41.431 --> 00:35:43.389
doesn't actually-- it
processes sets natively.

00:35:43.389 --> 00:35:45.742
So we need to also positionally
encode these vectors

00:35:45.742 --> 00:35:47.659
so that we basically
have both the information

00:35:47.659 --> 00:35:51.679
about the token identity and
its place in the sequence from 1

00:35:51.679 --> 00:35:53.869
to block size.

00:35:53.869 --> 00:35:56.659
Now, the information
about what and where

00:35:56.659 --> 00:35:58.879
is combined additively,
so the token embeddings

00:35:58.880 --> 00:36:02.750
and the positional embeddings
are just added exactly as here.
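
NOTE
A sketch of the token plus position embedding in plain Python. The table names wte and wpe and all sizes are hypothetical here, and in a real model both tables are learned parameters rather than fixed random numbers.
```python
import random
rng = random.Random(0)
vocab_size, block_size, n_embd = 65, 8, 4
# hypothetical small tables: one row per token id, one row per position
wte = [[rng.gauss(0.0, 0.02) for _ in range(n_embd)] for _ in range(vocab_size)]  # token embeddings
wpe = [[rng.gauss(0.0, 0.02) for _ in range(n_embd)] for _ in range(block_size)]  # position embeddings
def embed(idx):
    # "what" (token identity) plus "where" (position), combined additively
    assert len(idx) <= block_size
    return [[a + b for a, b in zip(wte[tok], wpe[t])] for t, tok in enumerate(idx)]
x = embed([47, 58, 1, 51])
```
The same token id gets a different vector at a different position, which is what lets attention, which otherwise treats its inputs as a set, tell positions apart.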

00:36:02.750 --> 00:36:06.800
So then there's
optional dropout,

00:36:06.800 --> 00:36:08.780
this x here basically
just contains

00:36:08.780 --> 00:36:14.870
the set of words
and their positions,

00:36:14.869 --> 00:36:16.786
and that feeds into the
blocks of transformer.

00:36:16.786 --> 00:36:18.744
And we're going to look
into what's in a block here.

00:36:18.744 --> 00:36:20.599
But for now,
this is just a series

00:36:20.599 --> 00:36:22.239
of blocks in a transformer.

00:36:22.239 --> 00:36:23.989
And then in the end,
there's a layer norm,

00:36:23.989 --> 00:36:26.799
and then you're
decoding the logits

00:36:26.800 --> 00:36:30.680
for the next word or next
integer in a sequence,

00:36:30.679 --> 00:36:33.469
using the linear projection of
the output of this transformer.

00:36:33.469 --> 00:36:36.859
So LM head here is short
for language model head.

00:36:36.860 --> 00:36:38.945
It's just a linear function.

00:36:38.945 --> 00:36:42.710
So basically, positionally
encode all the words,

00:36:42.710 --> 00:36:45.230
feed them into a
sequence of blocks,

00:36:45.230 --> 00:36:47.690
and then apply a linear
layer to get the probability

00:36:47.690 --> 00:36:50.336
distribution for
the next character.

00:36:50.336 --> 00:36:51.920
And then if we have
the targets, which

00:36:51.920 --> 00:36:54.057
we produced in the data loader--

00:36:54.057 --> 00:36:55.849
and you'll notice that
the targets are just

00:36:55.849 --> 00:36:59.297
the inputs offset
by one in time--

00:36:59.297 --> 00:37:01.380
then those targets feed
into a cross entropy loss.

00:37:01.380 --> 00:37:03.088
So this is just a
negative log likelihood,

00:37:03.088 --> 00:37:04.705
the typical classification loss.

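NOTE A sketch of how targets are the inputs shifted by one step and how they feed a cross-entropy (negative log likelihood) loss. The toy token stream and the all-zero logits are made up for illustration.
```python
import numpy as np
data = np.array([5, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5])  # a toy stream of encoded tokens
block_size = 8
def get_example(i):
    x = data[i : i + block_size]           # inputs
    y = data[i + 1 : i + 1 + block_size]   # targets: the same window shifted one step in time
    return x, y
def cross_entropy(logits, targets):
    # mean negative log likelihood of the correct next token over all positions
    z = logits - logits.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    picked = np.take_along_axis(log_probs, targets[..., None], axis=-1)
    return float(-picked.mean())
x, y = get_example(0)
vocab_size = 10
loss = cross_entropy(np.zeros((block_size, vocab_size)), y)  # uniform prediction
```
With all-zero logits each of the 10 classes gets probability 1/10, so the loss is log(10), about 2.303, a handy sanity check for an untrained model.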
00:37:04.704 --> 00:37:08.840
So now let's drill into
what's here in the blocks.

00:37:08.840 --> 00:37:11.470
So these blocks that are
applied sequentially,

00:37:11.469 --> 00:37:13.469
there's, again, as I
mentioned, this communicate

00:37:13.469 --> 00:37:15.000
phase and the compute phase.

00:37:15.000 --> 00:37:17.135
So in the communicate
phase, all the nodes

00:37:17.135 --> 00:37:21.260
get to talk to each other, and
so these nodes are basically,

00:37:21.260 --> 00:37:23.900
if our block size
is 8, then we are

00:37:23.900 --> 00:37:26.405
going to have eight
nodes in this graph.

00:37:26.405 --> 00:37:28.010
There's eight nodes
in this graph.

00:37:28.010 --> 00:37:30.250
The first node is pointed
to only by itself.

00:37:30.250 --> 00:37:33.324
The second node is pointed to
by the first node and itself.

00:37:33.324 --> 00:37:35.449
The third node is pointed
to by the first two nodes

00:37:35.449 --> 00:37:36.639
and itself, et cetera.

00:37:36.639 --> 00:37:38.940
So there's eight nodes here.

00:37:38.940 --> 00:37:42.472
So you apply-- there's a
residual pathway in x.

00:37:42.472 --> 00:37:43.139
You take it out.

00:37:43.139 --> 00:37:45.449
You apply a layer norm,
and then the self-attention

00:37:45.449 --> 00:37:47.879
so that these communicate,
these eight nodes communicate.

00:37:47.880 --> 00:37:50.220
But you have to keep in
mind that the batch is 4.

00:37:50.219 --> 00:37:54.179
So because batch is 4,
this is also applied--

00:37:54.179 --> 00:37:55.859
so we have eight
nodes communicating,

00:37:55.860 --> 00:37:58.443
but there's a batch of four of
them individually communicating

00:37:58.443 --> 00:37:59.880
among those eight nodes.

00:37:59.880 --> 00:38:02.380
There's no crisscross across
the batch dimension, of course.

00:38:02.380 --> 00:38:04.680
There's no batch norm
anywhere, luckily.

00:38:04.679 --> 00:38:06.809
And then once they've
exchanged information,

00:38:06.809 --> 00:38:09.630
they are processed using
the multi-layer perceptron.

00:38:09.630 --> 00:38:12.630
And that's the compute phase.

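NOTE A minimal sketch of one block's data flow, assuming pre-norm residual connections as described: take x off the residual pathway, layer norm, communicate, add back; then the same for compute. The attention and MLP here are deliberately crude, hypothetical placeholders (a causal running average and a ReLU) so the wiring can be checked on its own. The usage shows that batch elements never mix and that no position sees the future.
```python
import numpy as np
def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)
def causal_average(h):
    # placeholder "communicate" step: node t averages nodes 0..t (real attention learns these weights)
    T = h.shape[1]
    w = np.tril(np.ones((T, T)))
    w = w / w.sum(axis=1, keepdims=True)
    return w @ h                      # (T, T) @ (B, T, C) broadcasts over the batch
def relu_mlp(h):
    return np.maximum(h, 0.0)         # placeholder "compute" step, applied per node
def block(x, attn=causal_average, mlp=relu_mlp):
    x = x + attn(layer_norm(x))       # communicate, then add back onto the residual pathway
    x = x + mlp(layer_norm(x))        # compute, then add back onto the residual pathway
    return x
rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8, 16))       # batch of 4, eight nodes each, 16 features
out = block(x)
x_other = x.copy()
x_other[1] = rng.normal(size=(8, 16))        # change only batch element 1
out_other = block(x_other)
x_future = x.copy()
x_future[:, -1] = rng.normal(size=(4, 16))   # change only the last time step
out_future = block(x_future)
```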
00:38:12.630 --> 00:38:18.137
And then also here we are
missing the cross-attention

00:38:18.137 --> 00:38:19.679
because this is a
decoder-only model.

00:38:19.679 --> 00:38:21.277
So all we have is
this step here,

00:38:21.277 --> 00:38:22.860
the multi-headed
attention, and that's

00:38:22.860 --> 00:38:24.579
this line, the
communicate phase.

00:38:24.579 --> 00:38:27.119
And then we have the feed
forward, which is the MLP,

00:38:27.119 --> 00:38:29.710
and that's the compute phase.

00:38:29.710 --> 00:38:31.610
I'll take questions
a bit later.

00:38:31.610 --> 00:38:34.745
Then the MLP here is
fairly straightforward.

00:38:34.744 --> 00:38:38.069
The MLP is just individual
processing on each node,

00:38:38.070 --> 00:38:41.530
just transforming the feature
representation at that node.

00:38:41.530 --> 00:38:45.120
So applying a
two-layer neural net

00:38:45.119 --> 00:38:47.204
with a GELU nonlinearity,
which you can just

00:38:47.204 --> 00:38:49.079
think of as a ReLU
or something like that.

00:38:49.079 --> 00:38:51.400
It's just a nonlinearity.

00:38:51.400 --> 00:38:53.610
And then MLP is straightforward.

00:38:53.610 --> 00:38:55.760
I don't think there's
anything too crazy there.

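NOTE A sketch of the position-wise MLP, assuming a 4x hidden expansion (the common GPT choice) and the tanh approximation of GELU; both are illustrative choices here, and the weights are random stand-ins. The same weights act on every node independently, which the usage checks.
```python
import numpy as np
def gelu(x):
    # tanh approximation of GELU: close to 0 for very negative x, close to x for large x
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))
def mlp(x, W1, b1, W2, b2):
    # (..., C) -> (..., 4C) -> (..., C), acting on the last axis only
    return gelu(x @ W1 + b1) @ W2 + b2
rng = np.random.default_rng(0)
C = 16
W1, b1 = rng.normal(0.0, 0.02, (C, 4 * C)), np.zeros(4 * C)
W2, b2 = rng.normal(0.0, 0.02, (4 * C, C)), np.zeros(C)
x = rng.normal(size=(4, 8, C))
y = mlp(x, W1, b1, W2, b2)
```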
00:38:55.760 --> 00:38:57.760
And then this is the
causal self-attention part,

00:38:57.760 --> 00:38:59.750
the communication phase.

00:38:59.750 --> 00:39:01.539
So this is like
the meat of things

00:39:01.539 --> 00:39:03.670
and the most complicated part.

00:39:03.670 --> 00:39:06.599
It's only complicated
because of the batching

00:39:06.599 --> 00:39:10.349
and the implementation detail
of how you mask the connectivity

00:39:10.349 --> 00:39:13.619
in the graph so that
you can't obtain

00:39:13.619 --> 00:39:15.119
any information
from the future when

00:39:15.119 --> 00:39:16.327
you're predicting your token.

00:39:16.327 --> 00:39:18.429
Otherwise, it gives
away the information.

00:39:18.429 --> 00:39:23.099
So if I'm the fifth token and
if I'm the fifth position,

00:39:23.099 --> 00:39:26.279
then I'm getting the fourth
token coming into the input,

00:39:26.280 --> 00:39:29.010
and I'm attending to the
third, second, and first,

00:39:29.010 --> 00:39:32.160
and I'm trying to figure
out what is the next token.

00:39:32.159 --> 00:39:34.589
Well then, in this batch,
in the next element

00:39:34.590 --> 00:39:37.050
over in the time dimension,
the answer is at the input.

00:39:37.050 --> 00:39:40.360
So I can't get any
information from there.

00:39:40.360 --> 00:39:41.860
So that's why this
is all tricky,

00:39:41.860 --> 00:39:45.070
but basically, in
the forward pass,

00:39:45.070 --> 00:39:50.658
we are calculating the queries,
keys, and values based on x.

00:39:50.657 --> 00:39:52.449
So these are the keys,
queries, and values.

00:39:52.449 --> 00:39:54.444
Here, when I'm
computing the attention,

00:39:54.445 --> 00:39:58.019
I have the queries matrix
multiplying the keys.

00:39:58.019 --> 00:40:00.730
So this is the dot product in
parallel for all the queries

00:40:00.730 --> 00:40:03.400
and all the keys
in all the heads.

00:40:03.400 --> 00:40:06.160
So I failed to mention
that there's also

00:40:06.159 --> 00:40:08.679
the aspect of the heads, which
is also done all in parallel

00:40:08.679 --> 00:40:09.029
here.

00:40:09.030 --> 00:40:10.900
So we have the batch
dimension, the time dimension,

00:40:10.900 --> 00:40:12.369
and the head dimension,
and you end up

00:40:12.369 --> 00:40:14.779
with four-dimensional tensors,
and it's all really confusing.

00:40:14.780 --> 00:40:17.110
So I invite you to step through
it later and convince yourself

00:40:17.110 --> 00:40:19.059
that this is actually
doing the right thing.

00:40:19.059 --> 00:40:21.549
But basically, you have the
batch dimension, the head

00:40:21.550 --> 00:40:23.560
dimension and the
time dimension,

00:40:23.559 --> 00:40:25.250
and then you have
features at them.

00:40:25.250 --> 00:40:28.630
And so this is evaluating for
all the batch elements, for all

00:40:28.630 --> 00:40:31.300
the head elements, and
all the time elements,

00:40:31.300 --> 00:40:34.030
the simple Python that I gave
you earlier, which is query

00:40:34.030 --> 00:40:35.769
dot product key.

00:40:35.769 --> 00:40:38.949
Then here, we do a masked_fill,
and what this is doing

00:40:38.949 --> 00:40:44.259
is it's basically clamping the
attention between the nodes

00:40:44.260 --> 00:40:46.480
that are not supposed to
communicate to be negative

00:40:46.480 --> 00:40:47.110
infinity.

00:40:47.110 --> 00:40:48.485
And we're doing
negative infinity

00:40:48.485 --> 00:40:51.220
because we're about to softmax,
and so negative infinity will

00:40:51.219 --> 00:40:54.384
basically make the attention
to those elements zero.

00:40:54.385 --> 00:40:56.590
And so here we are going
to basically end up

00:40:56.590 --> 00:41:03.370
with the weights, the affinities
between these nodes, optional

00:41:03.369 --> 00:41:03.880
dropout.

00:41:03.880 --> 00:41:08.460
And then here, attention
matrix multiply v is basically

00:41:08.460 --> 00:41:10.960
the gathering of the information
according to the affinities

00:41:10.960 --> 00:41:11.829
we calculated.

00:41:11.829 --> 00:41:14.529
And this is just a
weighted sum of the values

00:41:14.530 --> 00:41:15.769
at all those nodes.

00:41:15.769 --> 00:41:19.030
So this matrix multiply
is doing that weighted sum.

00:41:19.030 --> 00:41:20.993
And then transpose
contiguous view

00:41:20.992 --> 00:41:22.659
because it's all
complicated and batched

00:41:22.659 --> 00:41:24.789
in four-dimensional
tensors, but it's really not

00:41:24.789 --> 00:41:26.889
doing anything,
optional dropout,

00:41:26.889 --> 00:41:30.679
and then a linear projection
back to the residual pathway.

00:41:30.679 --> 00:41:34.710
So this is implementing the
communication phase here.

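NOTE A NumPy sketch of the batched, multi-head causal self-attention walked through above, assuming nanoGPT-like conventions (one fused q/k/v projection, heads split out of the channel dimension, 1/sqrt(head size) scaling). Dropout and biases are omitted, and np.where stands in for PyTorch's masked_fill. The usage checks the causal property: changing the last token leaves every earlier output unchanged.
```python
import numpy as np
def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)
def causal_self_attention(x, W_qkv, W_proj, n_head):
    B, T, C = x.shape
    hs = C // n_head                                   # head size
    q, k, v = np.split(x @ W_qkv, 3, axis=-1)          # each (B, T, C)
    # move heads next to the batch dim: (B, T, C) -> (B, n_head, T, hs)
    q, k, v = (t.reshape(B, T, n_head, hs).transpose(0, 2, 1, 3) for t in (q, k, v))
    att = q @ k.transpose(0, 1, 3, 2) / np.sqrt(hs)    # (B, n_head, T, T) affinities
    allowed = np.tril(np.ones((T, T), dtype=bool))     # node t may look at nodes 0..t
    att = np.where(allowed, att, -np.inf)              # the masked_fill step
    att = softmax(att, axis=-1)                        # -inf becomes exactly 0 weight
    y = att @ v                                        # weighted sum of values: (B, n_head, T, hs)
    y = y.transpose(0, 2, 1, 3).reshape(B, T, C)       # put the heads back side by side
    return y @ W_proj                                  # linear projection back to the residual pathway
rng = np.random.default_rng(0)
B, T, C, n_head = 4, 8, 32, 4
W_qkv = rng.normal(0.0, 0.1, (C, 3 * C))
W_proj = rng.normal(0.0, 0.1, (C, C))
x = rng.normal(size=(B, T, C))
y = causal_self_attention(x, W_qkv, W_proj, n_head)
x2 = x.copy()
x2[:, -1] = rng.normal(size=(B, C))                    # change only the last time step
y2 = causal_self_attention(x2, W_qkv, W_proj, n_head)
```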
00:41:34.710 --> 00:41:37.869
Then you can train
this transformer.

00:41:37.869 --> 00:41:41.170
And then you can generate
infinite Shakespeare.

00:41:41.170 --> 00:41:43.090
And you will simply do this by--

00:41:43.090 --> 00:41:47.170
because our block size is 8,
we start with some token,

00:41:47.170 --> 00:41:50.500
say, in this case,
you

00:41:50.500 --> 00:41:53.050
can use something like a
new line as the start token.

00:41:53.050 --> 00:41:55.517
And then you communicate
only to yourself

00:41:55.516 --> 00:41:57.099
because there's a
single node, and you

00:41:57.099 --> 00:41:59.559
get the probability
distribution for the first word

00:41:59.559 --> 00:42:00.650
in the sequence.

00:42:00.650 --> 00:42:03.603
And then you decode it
for the first character

00:42:03.603 --> 00:42:04.269
in the sequence.

00:42:04.269 --> 00:42:05.559
You decode the character.

00:42:05.559 --> 00:42:06.549
And then you bring
back the character,

00:42:06.550 --> 00:42:08.019
and you re-encode
it as an integer.

00:42:08.019 --> 00:42:10.605
And now, you have
the second thing.

00:42:10.605 --> 00:42:12.760
And so you get--

00:42:12.760 --> 00:42:14.470
OK, we're at the first
position, and this

00:42:14.469 --> 00:42:17.659
is whatever integer it is,
add the positional encodings,

00:42:17.659 --> 00:42:19.659
goes into the sequence,
goes in the transformer,

00:42:19.659 --> 00:42:21.940
and again, this token
now communicates

00:42:21.940 --> 00:42:26.690
with the first token
and with itself.

00:42:26.690 --> 00:42:28.389
And so you just keep
plugging it back.

00:42:28.389 --> 00:42:31.000
And once you run out of the
block size, which is eight,

00:42:31.000 --> 00:42:33.130
you start to crop,
because you can never

00:42:33.130 --> 00:42:34.660
have block size more than
eight in the way you've

00:42:34.659 --> 00:42:35.701
trained this transformer.

00:42:35.702 --> 00:42:37.690
So we have more and more
context until eight.

00:42:37.690 --> 00:42:39.190
And then if you want to
generate beyond eight,

00:42:39.190 --> 00:42:41.481
you have to start cropping
because the transformer only

00:42:41.481 --> 00:42:43.690
works for eight elements
in time dimension.

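NOTE A sketch of the sampling loop with cropping, written against a hypothetical `model` callable that maps (B, T) token ids to (B, T, vocab) logits. The toy_model below is a bigram lookup table standing in for a trained transformer, instrumented to record how many tokens it is fed on each step.
```python
import numpy as np
def generate(model, idx, max_new_tokens, block_size, rng):
    """Autoregressive sampling; `model` maps (B, T) ids to (B, T, vocab) logits."""
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -block_size:]                  # crop to the last block_size tokens
        logits = model(idx_cond)[:, -1, :]               # only the last position's prediction is used
        p = np.exp(logits - logits.max(axis=-1, keepdims=True))
        p = p / p.sum(axis=-1, keepdims=True)
        nxt = np.array([[rng.choice(len(row), p=row)] for row in p])  # sample one id per sequence
        idx = np.concatenate([idx, nxt], axis=1)         # append it and go again
    return idx
vocab_size, block_size = 5, 8
rng = np.random.default_rng(0)
table = rng.normal(size=(vocab_size, vocab_size))
seen_lengths = []
def toy_model(idx):
    # hypothetical stand-in for a trained transformer: bigram logits from a lookup table
    seen_lengths.append(idx.shape[1])
    return table[idx]                                    # (B, T, vocab_size)
start = np.zeros((1, 1), dtype=np.int64)                 # token 0 playing the role of, say, a newline
out = generate(toy_model, start, 20, block_size, rng)
```
The context grows 1, 2, ..., 8 and then stays at 8: past that point the oldest tokens are cropped away.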
00:42:43.690 --> 00:42:47.170
And so all of these transformers
in the [INAUDIBLE] setting

00:42:47.170 --> 00:42:50.590
have a finite block
size or context length,

00:42:50.590 --> 00:42:54.460
and in typical models, this
will be 1,024 tokens or 2,048

00:42:54.460 --> 00:42:56.349
tokens, something like that.

00:42:56.349 --> 00:42:58.559
But these tokens are
usually like BPE tokens,

00:42:58.559 --> 00:43:00.434
or SentencePiece tokens,
or WordPiece tokens.

00:43:00.434 --> 00:43:02.539
There's many
different encodings.

00:43:02.539 --> 00:43:03.860
So it's not like that long.

00:43:03.860 --> 00:43:05.349
And so that's why, I
think, [INAUDIBLE].

00:43:05.349 --> 00:43:06.789
We really want to
expand the context size,

00:43:06.789 --> 00:43:08.469
and it gets gnarly
because the attention

00:43:08.469 --> 00:43:11.659
is quadratic in the
[INAUDIBLE] case.

00:43:11.659 --> 00:43:16.759
Now, if you want to implement
an encoder instead of a decoder

00:43:16.760 --> 00:43:18.680
attention,

00:43:18.679 --> 00:43:21.214
then all you have to
do is this [INAUDIBLE]

00:43:21.215 --> 00:43:23.340
and you just delete that line.

00:43:23.340 --> 00:43:25.414
So if you don't
mask the attention,

00:43:25.414 --> 00:43:27.289
then all the nodes
communicate to each other,

00:43:27.289 --> 00:43:29.389
and everything is
allowed, and information

00:43:29.389 --> 00:43:31.129
flows between all the nodes.

00:43:31.130 --> 00:43:35.750
So if you want to have the
encoder here, just delete it.

00:43:35.750 --> 00:43:38.030
All the encoder blocks
will use attention

00:43:38.030 --> 00:43:39.380
where this line is deleted.

00:43:39.380 --> 00:43:40.730
That's it.

00:43:40.730 --> 00:43:44.480
So you're allowing whatever--
this encoder might store say,

00:43:44.480 --> 00:43:46.880
10 tokens, 10 nodes,
and they are all

00:43:46.880 --> 00:43:51.240
allowed to communicate to each
other going up the transformer.

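NOTE A sketch of the one-line difference between decoder and encoder self-attention, simplified to a single head with no learned projections (queries, keys and values are all x itself, an assumption for brevity). With the mask the weight matrix is lower-triangular; without it every one of the 10 nodes attends to every other.
```python
import numpy as np
def attention_weights(x, causal):
    att = x @ x.T / np.sqrt(x.shape[-1])        # (T, T) affinities
    if causal:
        # decoder: keep the masking line so node t only sees nodes 0..t
        T = att.shape[0]
        att = np.where(np.tril(np.ones((T, T), dtype=bool)), att, -np.inf)
    # encoder: the masking line is simply not there
    att = np.exp(att - att.max(axis=-1, keepdims=True))
    return att / att.sum(axis=-1, keepdims=True)
rng = np.random.default_rng(0)
x = rng.normal(size=(10, 16))                    # e.g. 10 tokens, 16 features each
w_dec = attention_weights(x, causal=True)
w_enc = attention_weights(x, causal=False)
```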
00:43:51.239 --> 00:43:53.369
And then if you want to
implement cross-attention,

00:43:53.369 --> 00:43:55.327
so you have a full
encoder-decoder transformer,

00:43:55.327 --> 00:43:59.329
not just a decoder-only
transformer or a GPT.

00:43:59.329 --> 00:44:03.159
Then we need to also add
cross-attention in the middle.

00:44:03.159 --> 00:44:05.809
So here, there is a
self-attention piece where all

00:44:05.809 --> 00:44:06.469
the--

00:44:06.469 --> 00:44:08.802
there's a self-attention
piece, a cross-attention piece,

00:44:08.802 --> 00:44:09.980
and this MLP.

00:44:09.980 --> 00:44:12.320
And in the
cross-attention, we need

00:44:12.320 --> 00:44:14.570
to take the features from
the top of the encoder.

00:44:14.570 --> 00:44:16.789
We need to add one
more line here,

00:44:16.789 --> 00:44:20.090
and this would be the
cross-attention instead of a--

00:44:20.090 --> 00:44:22.340
I should have implemented
it instead of just pointing,

00:44:22.340 --> 00:44:23.300
I think.

00:44:23.300 --> 00:44:25.310
But there will be a
cross-attention line here.

00:44:25.309 --> 00:44:26.929
So we'll have three
lines because we

00:44:26.929 --> 00:44:28.190
need to add another block.

00:44:28.190 --> 00:44:31.400
And the queries will
come from x but the keys

00:44:31.400 --> 00:44:35.043
and the values will come
from the top of the encoder.

00:44:35.043 --> 00:44:36.710
And there will basically be
information

00:44:36.710 --> 00:44:38.126
flowing from the
encoder, strictly

00:44:38.126 --> 00:44:41.420
to all the nodes inside x.

00:44:41.420 --> 00:44:42.750
And then that's it.

00:44:42.750 --> 00:44:44.255
So it's a very
simple modification

00:44:44.255 --> 00:44:47.369
on the decoder attention.

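NOTE A sketch of the extra cross-attention line in a decoder block, assuming single-head attention with hypothetical weight matrices Wq, Wk, Wv and leaving out layer norms. Queries come from the decoder stream x; keys and values come from the top of the encoder, so the decoder (5 nodes here) and encoder (10 nodes) can differ in length, and no causal mask is applied to the encoder side.
```python
import numpy as np
def cross_attention(x, enc_out, Wq, Wk, Wv):
    q = x @ Wq                                   # (T, hs): queries from the decoder nodes
    k = enc_out @ Wk                             # (S, hs): keys from the top of the encoder
    v = enc_out @ Wv                             # (S, hs): values from the top of the encoder
    att = q @ k.T / np.sqrt(q.shape[-1])         # (T, S): each decoder node scores each encoder node
    att = np.exp(att - att.max(axis=-1, keepdims=True))
    att = att / att.sum(axis=-1, keepdims=True)  # no mask: the whole encoder output is visible
    return att @ v                               # (T, hs)
def decoder_block(x, enc_out, self_attn, cross_attn, mlp):
    x = x + self_attn(x)                         # masked self-attention among decoder nodes
    x = x + cross_attn(x, enc_out)               # the added line: information flows in from the encoder
    x = x + mlp(x)                               # per-node compute
    return x
rng = np.random.default_rng(0)
C = 16
Wq, Wk, Wv = (rng.normal(0.0, 0.1, (C, C)) for _ in range(3))
x = rng.normal(size=(5, C))                      # 5 decoder nodes
enc_out = rng.normal(size=(10, C))               # 10 encoder nodes
ca = cross_attention(x, enc_out, Wq, Wk, Wv)
```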
00:44:47.369 --> 00:44:49.469
So you'll hear people
talk that you have

00:44:49.469 --> 00:44:51.884
a decoder-only model like GPT.

00:44:51.885 --> 00:44:53.760
You can have an encoder-only
model like BERT,

00:44:53.760 --> 00:44:55.427
or you can have an
encoder-decoder model

00:44:55.427 --> 00:44:59.660
like say T5, doing things
like machine translation.

00:44:59.659 --> 00:45:04.143
And in BERT, you can't train
it using this language modeling

00:45:04.143 --> 00:45:06.059
setup that's
autoregressive, where you're just

00:45:06.059 --> 00:45:07.340
trying to predict next
[INAUDIBLE] in the sequence.

00:45:07.340 --> 00:45:09.720
You're training it with
slightly different objectives.

00:45:09.719 --> 00:45:12.000
You're putting in
the full sentence,

00:45:12.000 --> 00:45:14.454
and the full sentence is
allowed to communicate fully.

00:45:14.454 --> 00:45:16.829
And then you're trying to
classify sentiment or something

00:45:16.829 --> 00:45:18.039
like that.

00:45:18.039 --> 00:45:21.489
So you're not trying to model
the next token in the sequence.

00:45:21.489 --> 00:45:26.649
So these are trained
slightly differently

00:45:26.650 --> 00:45:31.789
using masking and other
denoising techniques.

00:45:31.789 --> 00:45:32.289
OK.

00:45:32.289 --> 00:45:34.570
So that's like the transformer.

00:45:34.570 --> 00:45:36.410
I'm going to continue.

00:45:36.409 --> 00:45:38.565
So yeah, maybe more questions.

00:45:38.565 --> 00:45:49.349
[INAUDIBLE]

00:46:01.710 --> 00:46:06.030
This is like we are enforcing
these constraints on it

00:46:06.030 --> 00:46:12.610
by just masking [INAUDIBLE]

00:46:12.610 --> 00:46:14.039
So I'm not sure
if I fully follow.

00:46:14.039 --> 00:46:16.769
So there's different ways
to look at this analogy,

00:46:16.769 --> 00:46:18.329
but one analogy is
you can interpret

00:46:18.329 --> 00:46:20.199
this graph as really fixed.

00:46:20.199 --> 00:46:22.230
It's just that every time
we do the communicate,

00:46:22.230 --> 00:46:23.400
we are using different weights.

00:46:23.400 --> 00:46:24.610
You can look at it that way.

00:46:24.610 --> 00:46:26.680
So if we have block size
of eight in my example,

00:46:26.679 --> 00:46:27.762
we would have eight nodes.

00:46:27.762 --> 00:46:29.309
Here we have 2, 4, 6.

00:46:29.309 --> 00:46:30.989
OK, so we'd have eight nodes.

00:46:30.989 --> 00:46:33.042
They would be connected in--

00:46:33.043 --> 00:46:35.460
you lay them out, and you only
connect from left to right.

00:46:35.460 --> 00:46:37.860
[INAUDIBLE]

00:46:42.635 --> 00:46:44.010
Why would they
connect-- usually,

00:46:44.010 --> 00:46:46.410
the connections don't change
as a function of the data

00:46:46.409 --> 00:46:47.460
or something like that--

00:46:47.460 --> 00:46:51.990
[INAUDIBLE]

00:47:00.293 --> 00:47:02.210
I don't think I've seen
a single example where

00:47:02.210 --> 00:47:03.139
the connectivity
changes dynamically

00:47:03.139 --> 00:47:04.021
as a function of the data.

00:47:04.021 --> 00:47:05.480
Usually, the
connectivity is fixed.

00:47:05.480 --> 00:47:07.610
If you have an encoder,
and you're training a BERT,

00:47:07.610 --> 00:47:09.500
you have how many
tokens you want,

00:47:09.500 --> 00:47:11.639
and they are fully connected.

00:47:11.639 --> 00:47:13.539
And if you have a
decoder-only model,

00:47:13.539 --> 00:47:15.289
you have this triangular
thing, and if you

00:47:15.289 --> 00:47:16.748
have encoder-decoder,
then you have

00:47:16.748 --> 00:47:21.269
awkwardly two pools of nodes.

00:47:21.269 --> 00:47:21.769
Yeah.

00:47:24.639 --> 00:47:25.230
Go ahead.

00:47:25.230 --> 00:47:45.010
[INAUDIBLE] I wonder, you
know much more about this

00:47:45.010 --> 00:47:46.604
than I know.

00:47:46.603 --> 00:48:00.629
But do you have a sense of
like if you ran [INAUDIBLE]

00:48:00.630 --> 00:48:08.555
In my head, I'm thinking
[INAUDIBLE] but then you also

00:48:08.554 --> 00:48:13.099
have different things for
one or more of [INAUDIBLE]--

00:48:13.099 --> 00:48:15.000
Yeah, it's really
hard to say, so that's

00:48:15.000 --> 00:48:17.219
why I think this paper is so
interesting because like, yeah,

00:48:17.219 --> 00:48:18.569
usually, you'd
see like the path,

00:48:18.570 --> 00:48:19.680
and maybe they had
path internally.

00:48:19.679 --> 00:48:20.981
They just didn't publish it.

00:48:20.981 --> 00:48:23.565
All you can see is things that
didn't look like a transformer.

00:48:23.565 --> 00:48:26.250
I mean, you have ResNets,
which have lots of this.

00:48:26.250 --> 00:48:29.820
But a ResNet would be
like this, but there's

00:48:29.820 --> 00:48:31.200
no self-attention component.

00:48:31.199 --> 00:48:35.579
But the MLP is there
kind of in a ResNet.

00:48:35.579 --> 00:48:37.710
So a ResNet looks
very much like this

00:48:37.710 --> 00:48:40.349
except there's no-- you can
use layer norms in ResNets,

00:48:40.349 --> 00:48:41.219
I believe, as well.

00:48:41.219 --> 00:48:43.509
Typically, sometimes,
they can be batch norms.

00:48:43.510 --> 00:48:45.210
So it is kind of like a ResNet.

00:48:45.210 --> 00:48:47.190
It is like they took
a ResNet, and they

00:48:47.190 --> 00:48:50.369
put in a self-attention
block in addition

00:48:50.369 --> 00:48:52.139
to the preexisting
MLP block, which

00:48:52.139 --> 00:48:53.741
is kind of like convolutions.

00:48:53.742 --> 00:48:55.575
And the MLP is, strictly
speaking, a convolution,

00:48:55.574 --> 00:48:59.099
a one-by-one convolution,
but I think

00:48:59.099 --> 00:49:04.110
the idea is similar in that MLP
is just like a typical weights,

00:49:04.110 --> 00:49:06.210
nonlinearity weights operation.

00:49:11.047 --> 00:49:13.089
But I will say, yeah, this
is kind of interesting

00:49:13.090 --> 00:49:15.968
because a lot of
work is not there,

00:49:15.967 --> 00:49:17.634
and then they give
you this transformer.

00:49:17.635 --> 00:49:18.820
And then it turns
out 5 years later,

00:49:18.820 --> 00:49:20.860
it's not changed, even though
everyone's trying to change it.

00:49:20.860 --> 00:49:23.095
So it's interesting to me
that it's like a package,

00:49:23.094 --> 00:49:25.487
in like a package,
which I think is really

00:49:25.487 --> 00:49:26.529
interesting historically.

00:49:26.530 --> 00:49:30.100
And I also talked
to the paper's authors,

00:49:30.099 --> 00:49:32.116
and they were
unaware of the impact

00:49:32.117 --> 00:49:33.950
that the transformer
would have at the time.

00:49:33.949 --> 00:49:37.419
So when you read this paper,
actually, it's unfortunate

00:49:37.420 --> 00:49:39.548
because this is the paper
that changed everything,

00:49:39.547 --> 00:49:41.589
but when people read it,
it's like question marks

00:49:41.590 --> 00:49:45.100
because it reads like a pretty
random machine translation

00:49:45.099 --> 00:49:46.139
paper.

00:49:46.139 --> 00:49:47.304
It's like, oh, we're
doing machine translation.

00:49:47.304 --> 00:49:48.596
Oh, here's a cool architecture.

00:49:48.597 --> 00:49:51.265
OK, great, good results.

00:49:51.264 --> 00:49:53.589
It doesn't know what's
going to happen.

00:49:53.590 --> 00:49:56.260
[LAUGHS] And so when
people read it today,

00:49:56.260 --> 00:50:00.550
I think they're
confused potentially.

00:50:00.550 --> 00:50:02.152
I will have some
tweets at the end,

00:50:02.152 --> 00:50:03.610
but I think I would
have renamed it

00:50:03.610 --> 00:50:08.755
with the benefit of hindsight
of like, well, I'll get to it.

00:50:08.755 --> 00:50:15.112
[INAUDIBLE]

00:50:20.920 --> 00:50:22.990
Yeah, I think that's a
good question as well.

00:50:22.989 --> 00:50:24.719
Currently, I mean,
I certainly don't

00:50:24.719 --> 00:50:27.329
love the autoregressive
modeling approach.

00:50:27.329 --> 00:50:29.250
I think it's kind of
weird to sample a token

00:50:29.250 --> 00:50:31.195
and then commit to it.

00:50:31.195 --> 00:50:36.809
So maybe there are
some ways, some hybrids

00:50:36.809 --> 00:50:38.309
with diffusion as
an example, which

00:50:38.309 --> 00:50:41.409
I think would be
really cool, or we'll

00:50:41.409 --> 00:50:44.319
find some other ways to edit
the sequences later but still

00:50:44.320 --> 00:50:47.177
in an autoregressive framework.

00:50:47.177 --> 00:50:49.510
But I think diffusion is
like an up and coming modeling

00:50:49.510 --> 00:50:51.677
approach that I personally
find much more appealing.

00:50:51.677 --> 00:50:54.190
When I sample text, I don't
go chunk, chunk, chunk,

00:50:54.190 --> 00:50:55.365
and commit.

00:50:55.364 --> 00:50:58.299
I do a draft one, and then
I do a better draft two.

00:50:58.300 --> 00:51:00.880
And that feels like
a diffusion process.

00:51:00.880 --> 00:51:02.480
So that would be my hope.

00:51:05.449 --> 00:51:07.759
OK, also a question.

00:51:07.760 --> 00:51:20.338
So yeah, you'd think
the [INAUDIBLE]

00:51:20.338 --> 00:51:21.880
And then once we
have the edge weights,

00:51:21.880 --> 00:51:23.894
we just have to multiply
it by the values,

00:51:23.894 --> 00:51:25.269
and then you just
[INAUDIBLE] it.

00:51:25.269 --> 00:51:27.159
Yes, yeah, it's right.

00:51:27.159 --> 00:51:30.339
And you think there's MLG
within graph neural networks

00:51:30.340 --> 00:51:32.590
and they'll potentially--

00:51:32.590 --> 00:51:34.990
I find the graph neural
networks like a confusing term

00:51:34.989 --> 00:51:38.209
because, I mean,
yeah, previously,

00:51:38.210 --> 00:51:40.262
there was this notion of--

00:51:40.262 --> 00:51:42.429
I feel like maybe today
everything is a graph neural

00:51:42.429 --> 00:51:44.799
network because a transformer
is a graph neural network

00:51:44.800 --> 00:51:45.760
processor.

00:51:45.760 --> 00:51:48.260
The native representation that
the transformer operates over

00:51:48.260 --> 00:51:51.680
is sets that are connected
by edges in a directed way.

00:51:51.679 --> 00:51:55.636
And so that's the native
representation, and then, yeah.

00:51:55.637 --> 00:51:57.720
OK, I should go on because
I still have 30 slides.

00:51:57.719 --> 00:51:59.539
[INAUDIBLE]

00:52:08.099 --> 00:52:11.339
Oh yeah, yeah, the square root
of d, I think, is basically

00:52:11.340 --> 00:52:14.130
like if you're initializing
with random weights

00:52:14.130 --> 00:52:17.140
setup from a [INAUDIBLE] as
your dimension size grows,

00:52:17.139 --> 00:52:19.349
so do your values;
the variance grows.

00:52:19.349 --> 00:52:23.400
And then your softmax will just
become a one-hot vector.

00:52:23.400 --> 00:52:25.410
So it's just a way to
control the variance

00:52:25.409 --> 00:52:28.049
and bring it to always be
in a good range for softmax

00:52:28.050 --> 00:52:31.670
with a nice diffuse distribution.

00:52:31.670 --> 00:52:37.869
OK, so it's almost like
an initialization thing.

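NOTE A quick numerical check of the point above, assuming unit-variance random queries and keys as at initialization. The variance of a raw dot product grows roughly like the dimension d, and dividing by sqrt(d) brings it back to about 1. The unscaled softmax also comes out sharper (lower entropy, closer to one-hot).
```python
import numpy as np
rng = np.random.default_rng(0)
variances = {}
for d in (16, 256, 4096):
    q = rng.normal(size=(1000, d))               # 1000 random unit-variance queries
    k = rng.normal(size=(1000, d))               # 1000 random unit-variance keys
    dots = (q * k).sum(axis=1)                   # 1000 independent dot products
    variances[d] = (dots.var(), (dots / np.sqrt(d)).var())
def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()
def entropy(p):
    p = p[p > 0]
    return float(-(p * np.log(p)).sum())
d = 4096
q, K = rng.normal(size=d), rng.normal(size=(8, d))   # one query, eight keys
raw = softmax(K @ q)                      # logits with std around sqrt(d) = 64
scaled = softmax(K @ q / np.sqrt(d))      # logits with std around 1
```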
00:52:37.869 --> 00:52:41.469
OK, so transformers
have been applied

00:52:41.469 --> 00:52:44.319
to all the other fields,
and the way this was done

00:52:44.320 --> 00:52:46.900
is in my opinion,
ridiculous ways

00:52:46.900 --> 00:52:49.389
honestly because I was a
computer vision person,

00:52:49.389 --> 00:52:51.400
and you have ConvNets,
and they make sense.

00:52:51.400 --> 00:52:53.840
So what we're doing now
with ViTs as an example is

00:52:53.840 --> 00:52:56.215
you take an image and you chop
it up into little squares.

00:52:56.215 --> 00:52:57.802
And then those
squares, literally,

00:52:57.802 --> 00:52:59.260
feed into a
transformer, and that's

00:52:59.260 --> 00:53:01.900
it, which is kind of ridiculous.

00:53:01.900 --> 00:53:06.389
And so, I mean, yeah,
and so the transformer

00:53:06.389 --> 00:53:08.670
doesn't even, in the simplest
case, really know where

00:53:08.670 --> 00:53:10.470
these patches might come from.

00:53:10.469 --> 00:53:12.929
They are usually
positionally encoded,

00:53:12.929 --> 00:53:16.379
but it has to rediscover
a lot of the structure,

00:53:16.380 --> 00:53:19.180
I think, of them in some ways.

00:53:19.179 --> 00:53:23.089
And it's kind of weird
to approach it that way.

00:53:23.090 --> 00:53:25.579
But it's just the
simplest baseline

00:53:25.579 --> 00:53:27.672
of just chopping up big
images into small squares

00:53:27.672 --> 00:53:29.839
and feeding them in as the
individual nodes actually

00:53:29.840 --> 00:53:30.620
works fairly well.

00:53:30.619 --> 00:53:32.690
And then this is in a
transformer encoder,

00:53:32.690 --> 00:53:34.760
so all the patches are
talking to each other

00:53:34.760 --> 00:53:36.960
throughout the
entire transformer.

00:53:36.960 --> 00:53:39.494
And the number of nodes
here would be like nine.

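NOTE A sketch of the "chop the image into squares" step for a ViT-style encoder, assuming a 48x48 RGB toy image and 16x16 patches so it yields the nine nodes mentioned above. A real ViT would then apply a learned linear projection and add position embeddings; only the chopping is shown.
```python
import numpy as np
def patchify(img, p):
    """Split an (H, W, C) image into non-overlapping p x p patches, flattened to vectors."""
    H, W, C = img.shape
    assert H % p == 0 and W % p == 0, "image size must be a multiple of the patch size"
    x = img.reshape(H // p, p, W // p, p, C)
    x = x.transpose(0, 2, 1, 3, 4)          # (patch rows, patch cols, p, p, C)
    return x.reshape(-1, p * p * C)         # (num_patches, p*p*C): one node per patch
img = np.arange(48 * 48 * 3, dtype=np.float32).reshape(48, 48, 3)
tokens = patchify(img, 16)
```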
00:53:42.284 --> 00:53:44.909
Also, in speech recognition, you
just take your mel spectrogram,

00:53:44.909 --> 00:53:46.937
and you chop it up into
slices and you feed them

00:53:46.938 --> 00:53:47.730
into a transformer.

00:53:47.730 --> 00:53:49.920
So there was a paper like
this, but also Whisper.

00:53:49.920 --> 00:53:51.720
Whisper is a
copy-paste transformer.

00:53:51.719 --> 00:53:55.199
If you saw Whisper from OpenAI,
you just chop up the mel spectrogram

00:53:55.199 --> 00:53:57.547
and feed it into a
transformer and then pretend

00:53:57.547 --> 00:53:58.589
you're dealing with text.

00:53:58.590 --> 00:54:00.870
And it works very well.

00:54:00.869 --> 00:54:03.692
Decision transformer in RL,
you take your states, actions,

00:54:03.693 --> 00:54:05.610
and reward that you
experience in the environment,

00:54:05.610 --> 00:54:07.693
and you just pretend
it's a language.

00:54:07.693 --> 00:54:09.610
Then you start to model
the sequences of that,

00:54:09.610 --> 00:54:11.640
and then you can use
that for planning later.

00:54:11.639 --> 00:54:13.319
That works really well.

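NOTE A sketch of flattening a trajectory into one token sequence. It assumes the return-to-go formulation from the Decision Transformer paper (the sum of future rewards at each step, rather than the raw per-step reward) and uses hypothetical placeholder strings for states and actions; in practice each element would be embedded into a vector before entering the transformer.
```python
import numpy as np
def returns_to_go(rewards):
    # R_t = sum of rewards from step t to the end of the episode
    return np.cumsum(rewards[::-1])[::-1]
def interleave(rtg, states, actions):
    """Flatten a trajectory into the token order (R_1, s_1, a_1, R_2, s_2, a_2, ...)."""
    seq = []
    for r, s, a in zip(rtg, states, actions):
        seq += [("return", float(r)), ("state", s), ("action", a)]
    return seq
rewards = np.array([0.0, 1.0, 0.0, 2.0])
rtg = returns_to_go(rewards)
seq = interleave(rtg, ["s1", "s2", "s3", "s4"], ["a1", "a2", "a3", "a4"])
```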
00:54:13.320 --> 00:54:15.382
Even things like AlphaFold,
so we were briefly

00:54:15.382 --> 00:54:17.590
talking about molecules and
how you can plug them in.

00:54:17.590 --> 00:54:19.507
So at the heart of
AlphaFold, computationally,

00:54:19.507 --> 00:54:21.907
is also a transformer.

00:54:21.907 --> 00:54:23.949
One thing I wanted to also
say about transformers

00:54:23.949 --> 00:54:26.289
is I find that
they're very flexible,

00:54:26.289 --> 00:54:28.150
and I really enjoy that.

00:54:28.150 --> 00:54:31.228
I'll give you an
example from Tesla.

00:54:31.228 --> 00:54:32.769
You have a ConvNet
that takes an image

00:54:32.769 --> 00:54:34.300
and makes predictions
about the image.

00:54:34.300 --> 00:54:35.967
And then the big
question is, how do you

00:54:35.967 --> 00:54:37.269
feed in extra information?

00:54:37.269 --> 00:54:38.920
And it's not always
trivial. Say I

00:54:38.920 --> 00:54:40.389
had additional
information that I

00:54:40.389 --> 00:54:43.480
want
the outputs to be informed by.

00:54:43.480 --> 00:54:45.112
Maybe I have other
sensors like Radar.

00:54:45.112 --> 00:54:47.320
Maybe I have some map
information, or a vehicle type,

00:54:47.320 --> 00:54:48.085
or some audio.

00:54:48.085 --> 00:54:50.710
And the question is, how do you
feed information into a ConvNet?

00:54:50.710 --> 00:54:52.329
Like where do you feed it in?

00:54:52.329 --> 00:54:54.429
Do you concatenate it?

00:54:54.429 --> 00:54:55.210
Do you add it?

00:54:55.210 --> 00:54:56.349
At what stage?

00:54:56.349 --> 00:54:58.202
And so with a transformer,
it's much easier

00:54:58.202 --> 00:55:00.369
because you just take
whatever you want, you chop it

00:55:00.369 --> 00:55:02.500
up into pieces, and you
feed it in with a set

00:55:02.500 --> 00:55:03.500
of what you had before.

00:55:03.500 --> 00:55:04.690
And you let the
self-attention figure out

00:55:04.690 --> 00:55:06.106
how everything
should communicate.

00:55:06.106 --> 00:55:07.719
And that actually
apparently works.

00:55:07.719 --> 00:55:10.119
So just chop up everything
and throw it into the mix

00:55:10.119 --> 00:55:11.739
is like the way.

00:55:11.739 --> 00:55:15.759
And it frees neural
nets from this burden

00:55:15.760 --> 00:55:19.332
of Euclidean space,
where previously you

00:55:19.331 --> 00:55:21.789
had to arrange your computation
to conform to the Euclidean

00:55:21.789 --> 00:55:25.304
space or three dimensions of how
you're laying out the compute.

00:55:25.304 --> 00:55:26.679
Like the compute
actually kind of

00:55:26.679 --> 00:55:29.859
happens in almost like 3D
space if you think about it.

00:55:29.860 --> 00:55:32.050
But in attention,
everything is just sets.

00:55:32.050 --> 00:55:33.730
So it's a very
flexible framework,

00:55:33.730 --> 00:55:35.530
and you can just throw in stuff
into your conditioning set.

00:55:35.530 --> 00:55:37.155
And everything just
gets self-attended over.

00:55:37.155 --> 00:55:39.595
So it's quite beautiful
from that perspective.

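NOTE A sketch of the "chop up everything and throw it into the mix" recipe, with made-up inputs (image patches, radar returns, a one-hot vehicle type) standing in for real sensors. The per-input linear projections and the learned type embedding that tags each token's source are assumptions for illustration, not a description of Tesla's system. Once everything has the same width it is one set, and unmasked self-attention decides how the pieces interact.
```python
import numpy as np
rng = np.random.default_rng(0)
n_embd = 32
inputs = {
    "image": rng.normal(size=(9, 768)),    # e.g. nine flattened image patches
    "radar": rng.normal(size=(5, 4)),      # e.g. five radar returns with 4 features each
    "vehicle": np.eye(3)[[1]],             # a single one-hot "vehicle type" token
}
# one hypothetical linear projection per input kind, into the shared width n_embd
proj = {name: rng.normal(0.0, 0.02, (a.shape[1], n_embd)) for name, a in inputs.items()}
# a hypothetical learned embedding per input kind, so the model can tell sources apart
type_emb = {name: rng.normal(0.0, 0.02, n_embd) for name in inputs}
tokens = np.concatenate(
    [a @ proj[name] + type_emb[name] for name, a in inputs.items()], axis=0
)
```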
00:55:39.594 --> 00:55:43.219
OK, so now what exactly makes
transformers so effective?

00:55:43.219 --> 00:55:44.719
I think a good
example of this comes

00:55:44.719 --> 00:55:48.230
from the GPT-3 paper, which
I encourage people to read.

00:55:48.230 --> 00:55:50.280
Language Models are
Few-Shot Learners.

00:55:50.280 --> 00:55:52.280
I would have probably
renamed this a little bit.

00:55:52.280 --> 00:55:54.380
I would have said
something like transformers

00:55:54.380 --> 00:55:57.769
are capable of in-context
learning or meta-learning.

00:55:57.769 --> 00:56:00.097
That's like what makes
them really special.

00:56:00.097 --> 00:56:02.180
So basically the setting
that they're working with

00:56:02.179 --> 00:56:03.679
is, OK, I have some
context, and I'm

00:56:03.679 --> 00:56:04.887
trying-- like say, a passage.

00:56:04.887 --> 00:56:06.335
This is just one
example of many.

00:56:06.335 --> 00:56:08.840
I have a passage, and I'm
asking questions about it.

00:56:08.840 --> 00:56:12.762
And then as part of the
context in the prompt,

00:56:12.762 --> 00:56:14.470
I'm giving the questions
and the answers.

00:56:14.469 --> 00:56:16.009
So I'm giving one example
of question-answer,

00:56:16.010 --> 00:56:17.468
another example of
question-answer,

00:56:17.467 --> 00:56:19.889
another example of
question-answer, and so on.
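
NOTE A minimal sketch of the few-shot setting described above. It assumes a plain "Q:"/"A:" text layout. Both that layout and the helper name build_few_shot_prompt are hypothetical and are not taken from the GPT-3 paper. The worked question-answer pairs are simply part of the prompt, and a model would be asked to continue the text after the final "A:".
```python
# Hypothetical helper: the context holds a passage, several worked
# question-answer pairs, and finally the open question to be completed.
def build_few_shot_prompt(passage, examples, query):
    lines = [passage]
    for question, answer in examples:
        lines.append(f"Q: {question}")
        lines.append(f"A: {answer}")
    lines.append(f"Q: {query}")
    lines.append("A:")  # a language model would "answer" by completing this line
    return "\n".join(lines)
examples = [
    ("Who is the passage about?", "A fox"),
    ("What does the fox jump over?", "A dog"),
]
prompt = build_few_shot_prompt(
    "The quick brown fox jumps over the lazy dog.", examples, "What color is the fox?"
)
print(prompt)
```
Adding more pairs to examples only lengthens the context. No weights change.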

00:56:19.889 --> 00:56:21.799
And this becomes--

00:56:21.800 --> 00:56:24.289
Oh yeah, people are going
to have to leave soon, huh?

00:56:24.289 --> 00:56:25.634
OK, is this really important?

00:56:25.635 --> 00:56:26.177
Let me think.

00:56:29.454 --> 00:56:31.329
OK, so what's really
interesting is basically

00:56:31.329 --> 00:56:35.380
like with more examples
given in a context,

00:56:35.380 --> 00:56:37.200
the accuracy improves.

00:56:37.199 --> 00:56:39.199
And so what that suggests
is that the transformer

00:56:39.199 --> 00:56:42.159
is able to somehow
learn in the activations

00:56:42.159 --> 00:56:43.629
without doing any
gradient descent

00:56:43.630 --> 00:56:45.260
in a typical
fine-tuning fashion.

00:56:45.260 --> 00:56:48.460
So if you fine-tune, you have to
give an example and the answer,

00:56:48.460 --> 00:56:51.246
and you fine-tune it,
using gradient descent.

00:56:51.246 --> 00:56:53.079
But it looks like the
transformer internally

00:56:53.079 --> 00:56:54.519
in its weights is
doing something

00:56:54.519 --> 00:56:56.050
that looks like potentially
gradient descent, some kind

00:56:56.050 --> 00:56:57.430
of meta-learning in the
weights of the transformer

00:56:57.429 --> 00:56:59.049
as it is reading the prompt.

00:56:59.050 --> 00:57:01.678
And so in this paper,
they go into, OK,

00:57:01.677 --> 00:57:03.969
distinguishing this outer
loop with stochastic gradient

00:57:03.969 --> 00:57:06.302
descent and this inner loop
of in-context learning.

00:57:06.302 --> 00:57:08.679
So the inner loop is as
the transformer is reading

00:57:08.679 --> 00:57:12.339
the sequence, almost, and the
outer loop is the training

00:57:12.340 --> 00:57:14.032
by gradient descent.

00:57:14.032 --> 00:57:15.490
So basically,
there's some training

00:57:15.489 --> 00:57:17.019
happening in the activations
of the transformer

00:57:17.019 --> 00:57:18.730
as it is consuming
a sequence that

00:57:18.730 --> 00:57:21.099
maybe very much looks
like gradient descent.

00:57:21.099 --> 00:57:23.307
And so there are some recent
papers that hint at this

00:57:23.307 --> 00:57:23.929
and study it.

00:57:23.929 --> 00:57:25.387
And so as an example,
in this paper

00:57:25.387 --> 00:57:28.719
here, they propose something
called the RAW operator.

00:57:28.719 --> 00:57:32.072
And they argue that the
RAW operator is implemented

00:57:32.072 --> 00:57:33.489
by the transformer,
and then they show

00:57:33.489 --> 00:57:35.289
that you can implement
things like ridge regression

00:57:35.289 --> 00:57:36.599
on top of the RAW operator.
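
NOTE A minimal sketch of in-context ridge regression. This is the kind of learning algorithm the speaker says transformers were shown to be able to implement. It is plain NumPy: it is not a transformer, and it is not the RAW operator itself. The (x, y) pairs stand in for the examples in a prompt. The prediction for a new query is computed from those pairs alone, with no pretrained weights. The function name ridge_predict and the regularization value are illustrative assumptions.
```python
import numpy as np
# The in-context examples (x_i, y_i) play the role of the prompt; nothing is
# trained ahead of time, the predictor is computed from the context alone.
def ridge_predict(xs, ys, x_query, lam=0.1):
    X = np.asarray(xs, dtype=float)  # shape (n_examples, d)
    y = np.asarray(ys, dtype=float)  # shape (n_examples,)
    d = X.shape[1]
    # Closed-form ridge solution: w = (X^T X + lam * I)^{-1} X^T y
    w = np.linalg.solve(X.T @ X + lam * np.eye(d), X.T @ y)
    return float(np.asarray(x_query, dtype=float) @ w)
rng = np.random.default_rng(0)
w_true = np.array([2.0, -1.0])
xs = rng.normal(size=(20, 2))
ys = xs @ w_true  # noiseless targets from a hidden linear rule
pred = ridge_predict(xs, ys, [1.0, 1.0], lam=1e-6)
print(round(pred, 3))
```
With lam near zero and noiseless targets, the prediction for [1, 1] should come out close to 2 - 1 = 1.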

00:57:36.599 --> 00:57:39.011
And so this is giving--

00:57:39.012 --> 00:57:40.720
There are papers
hinting that maybe there

00:57:40.719 --> 00:57:42.927
is something that looks
like gradient-based learning

00:57:42.927 --> 00:57:45.250
inside the activations
of the transformer.

00:57:45.250 --> 00:57:47.590
And I think this is not
impossible to think through

00:57:47.590 --> 00:57:49.720
because what is
gradient-based learning?

00:57:49.719 --> 00:57:52.179
Forward pass, backward
pass, and then update.

00:57:52.179 --> 00:57:54.250
Oh, that looks like
a ResNet, right,

00:57:54.250 --> 00:57:57.099
because you're adding
to the weights.

00:57:57.099 --> 00:57:59.511
So you start with an initial
random set of weights,

00:57:59.512 --> 00:58:01.720
forward pass, backward pass,
and update your weights,

00:58:01.719 --> 00:58:04.096
and then forward pass, backward
pass, update the weights.

00:58:04.097 --> 00:58:04.930
Looks like a ResNet.
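
NOTE A toy illustration of the analogy above. It assumes a one-parameter loss f(w) = (w - 3)^2, chosen only for illustration. Each gradient-descent step computes an update and adds it onto the current weights. That is the same additive form, x + F(x), as a ResNet block, so running both loops gives identical numbers. This only shows the shared structure. It does not show that a transformer actually performs these steps.
```python
# Gradient of the toy loss f(w) = (w - 3)^2.
def grad(w):
    return 2.0 * (w - 3.0)
def gd_step(w, lr=0.1):
    delta = -lr * grad(w)  # the update a forward + backward pass would produce
    return w + delta       # added onto the running weights (residual form)
def resnet_block(x, residual_fn):
    return x + residual_fn(x)  # same additive structure as gd_step
w = 0.0
for _ in range(50):
    w = gd_step(w)
# The same trajectory, written as a stack of 50 residual blocks.
x = 0.0
for _ in range(50):
    x = resnet_block(x, lambda v: -0.1 * grad(v))
print(w, x)
```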

00:58:04.929 --> 00:58:10.179
The transformer is a ResNet.
So, much more hand-wavy,

00:58:10.179 --> 00:58:11.889
but basically, some
papers are trying

00:58:11.889 --> 00:58:14.525
to hint at why that would
be potentially possible.

00:58:14.525 --> 00:58:16.900
And then I have a bunch of
tweets I just copy-pasted here

00:58:16.900 --> 00:58:18.639
at the end.

00:58:18.639 --> 00:58:20.519
These were like meant for
general consumption,

00:58:20.519 --> 00:58:22.900
so they're a bit more high-level
and hypey a little bit.

00:58:22.900 --> 00:58:26.079
But I'm talking about why this
architecture is so interesting

00:58:26.079 --> 00:58:27.994
and why potentially
it became so popular.

00:58:27.994 --> 00:58:29.619
And I think it
simultaneously optimizes

00:58:29.619 --> 00:58:31.464
three properties that, I
think, are very desirable.

00:58:31.465 --> 00:58:33.130
Number one, the
transformer is very

00:58:33.130 --> 00:58:35.865
expressive in the forward pass.

00:58:35.865 --> 00:58:37.509
It sort of like it's
able to implement

00:58:37.510 --> 00:58:39.552
very interesting functions,
potentially functions

00:58:39.552 --> 00:58:41.920
that can even do meta-learning.

00:58:41.920 --> 00:58:43.659
Number two, it is very
optimizable thanks

00:58:43.659 --> 00:58:45.429
to things like residual
connections, layer norms,

00:58:45.429 --> 00:58:45.940
and so on.

00:58:45.940 --> 00:58:47.731
And number three, it's
extremely efficient.

00:58:47.731 --> 00:58:49.929
This is not always appreciated,
but the transformer,

00:58:49.929 --> 00:58:51.554
if you look at the
computational graph,

00:58:51.554 --> 00:58:53.649
is a shallow, wide
network, which

00:58:53.650 --> 00:58:56.224
is perfect to take advantage
of the parallelism of GPUs.

00:58:56.224 --> 00:58:58.599
So I think the transformer
was designed very deliberately

00:58:58.599 --> 00:59:00.730
to run efficiently on GPUs.

00:59:00.730 --> 00:59:02.650
There's previous
work like Neural GPU

00:59:02.650 --> 00:59:05.680
that I really enjoy as
well, which is really just

00:59:05.679 --> 00:59:08.559
like how do we design neural
nets that are efficient on GPUs

00:59:08.559 --> 00:59:10.420
and thinking backwards from the
constraints of the hardware,

00:59:10.420 --> 00:59:11.740
which I think is a
very interesting way

00:59:11.739 --> 00:59:12.489
to think about it.

00:59:17.929 --> 00:59:21.789
Oh yeah, so here, I'm saying,
I probably would have called--

00:59:21.789 --> 00:59:24.489
I probably would've called the
transformer a general purpose

00:59:24.489 --> 00:59:27.819
efficient optimizable
computer instead of "Attention

00:59:27.820 --> 00:59:28.570
Is All You Need."

00:59:28.570 --> 00:59:31.930
That's what I would have maybe
in hindsight called that paper.

00:59:31.929 --> 00:59:37.349
It's proposing a model that
is very general purpose, so

00:59:37.349 --> 00:59:38.539
the forward pass is expressive.

00:59:38.539 --> 00:59:40.759
It's very efficient
in terms of GPU usage

00:59:40.760 --> 00:59:44.720
and is easily optimizable by
gradient descent and trains

00:59:44.719 --> 00:59:46.511
very nicely.

00:59:46.512 --> 00:59:48.730
And then I have some
other hype tweets here.

00:59:51.489 --> 00:59:53.339
Anyway, so you can
read them later.

00:59:53.340 --> 00:59:55.090
But I think this one
is maybe interesting.

00:59:55.090 --> 00:59:58.360
So if previous neural nets
are special purpose computers

00:59:58.360 --> 01:00:00.490
designed for a
specific task, GPT

01:00:00.489 --> 01:00:03.789
is a general purpose computer,
reconfigurable at runtime

01:00:03.789 --> 01:00:06.039
to run natural
language programs.

01:00:06.039 --> 01:00:08.920
So the programs are
given as prompts,

01:00:08.920 --> 01:00:12.220
and then GPT runs the program
by completing the document.

01:00:12.219 --> 01:00:16.959
So I really like these analogies
personally to computers.

01:00:16.960 --> 01:00:18.639
It's just like a
powerful computer,

01:00:18.639 --> 01:00:22.199
and it's optimizable
by gradient descent.

01:00:22.199 --> 01:00:30.614
And I don't know--

01:00:30.614 --> 01:00:31.114
OK, yeah.

01:00:31.114 --> 01:00:31.614
That's it.

01:00:31.614 --> 01:00:33.376
[LAUGHTER]

01:00:33.376 --> 01:00:35.460
You can read the tweets
later, but that's it for now.

01:00:35.460 --> 01:00:36.050
I'll just thank you.

01:00:36.050 --> 01:00:37.050
I'll just leave this up.

01:00:45.367 --> 01:00:46.659
Sorry, I just found this tweet.

01:00:46.659 --> 01:00:49.599
So turns out that if you
scale up the training set

01:00:49.599 --> 01:00:51.940
and use a powerful enough
neural net like a transformer,

01:00:51.940 --> 01:00:53.815
the network becomes a
kind of general purpose

01:00:53.815 --> 01:00:54.720
computer over text.

01:00:54.719 --> 01:00:56.527
So I think that's a nice
way to look at it.

01:00:56.527 --> 01:00:58.569
And instead of performing
a single text sequence,

01:00:58.570 --> 01:01:00.340
you can design the
sequence in the prompt.

01:01:00.340 --> 01:01:02.230
And because the transformer
is both powerful

01:01:02.230 --> 01:01:05.110
but also is trained on a large
enough, very hard data set,

01:01:05.110 --> 01:01:07.539
it becomes this general
purpose text computer.

01:01:07.539 --> 01:01:11.199
And so I think that's a kind of
interesting way to look at it.

01:01:11.199 --> 01:01:13.371
Yeah.

01:01:13.371 --> 01:01:16.750
[INAUDIBLE]

01:02:01.289 --> 01:02:04.179
And I guess my question
is [INAUDIBLE] how

01:02:04.179 --> 01:02:05.597
much do you think [INAUDIBLE]?

01:02:10.019 --> 01:02:25.795
really because it's mostly
more efficient or [INAUDIBLE]

01:02:25.795 --> 01:02:27.170
So I think there's
a bit of that.

01:02:27.170 --> 01:02:29.284
Yeah, so I would say
RNNs in principle,

01:02:29.284 --> 01:02:31.456
yes, they can implement
arbitrary programs.

01:02:31.456 --> 01:02:33.664
I think it's like a useless
statement to some extent

01:02:33.664 --> 01:02:35.795
because they're probably--

01:02:35.795 --> 01:02:37.670
I'm sure that they're
probably expressive

01:02:37.670 --> 01:02:40.369
in the sense of power,
in that they can implement

01:02:40.369 --> 01:02:43.069
these arbitrary functions.

01:02:43.070 --> 01:02:44.250
But they're not optimizable.

01:02:44.250 --> 01:02:46.250
And they're certainly not
efficient because they

01:02:46.250 --> 01:02:47.750
are serial computing devices.

01:02:50.163 --> 01:02:51.829
So if you look at it
as a compute graph,

01:02:51.829 --> 01:02:58.264
RNNs are a very long,
thin compute graph.

01:02:58.264 --> 01:03:00.650
If you stretched out
the neurons and you looked--

01:03:00.650 --> 01:03:02.255
like take all the individual
neurons' interconnectivity,

01:03:02.255 --> 01:03:04.460
and stretch them out, and
try to visualize them.

01:03:04.460 --> 01:03:07.070
RNNs would be like a very
long graph, and that's bad.

01:03:07.070 --> 01:03:08.570
And it's bad also
for optimizability

01:03:08.570 --> 01:03:10.980
because I don't
exactly know why,

01:03:10.980 --> 01:03:13.789
but just the rough intuition
is when you're backpropagating,

01:03:13.789 --> 01:03:15.574
you don't want to
make too many steps.

01:03:15.574 --> 01:03:19.384
And so transformers are a
shallow wide graph, and so

01:03:19.385 --> 01:03:23.983
from supervision to inputs is
a very small number of hops.

01:03:23.983 --> 01:03:25.400
And there are long
residual pathways,

01:03:25.400 --> 01:03:26.990
which make gradients
flow very easily.

01:03:26.989 --> 01:03:28.364
And there's all
these layer norms

01:03:28.364 --> 01:03:32.509
to control the scales of
all of those activations.

01:03:32.510 --> 01:03:34.910
And so there's
not too many hops,

01:03:34.909 --> 01:03:36.980
and you're going from
supervision to input

01:03:36.980 --> 01:03:40.840
very quickly, and the gradient just
flows through the graph.

01:03:40.840 --> 01:03:42.420
And it can all be
done in parallel,

01:03:42.420 --> 01:03:43.724
so you don't need to do this--

01:03:43.724 --> 01:03:46.029
with encoder and decoder RNNs, you
have to go from the first word,

01:03:46.030 --> 01:03:47.447
then second word,
then third word.

01:03:47.447 --> 01:03:49.329
But here in the transformer,
every single word

01:03:49.329 --> 01:03:54.699
is processed completely in
parallel, which is kind of a--

01:03:54.699 --> 01:03:57.039
So I think all of these are
really important because all

01:03:57.039 --> 01:03:57.719
of these are really important.

01:03:57.719 --> 01:04:00.399
And I think number 3 is less
talked about but extremely

01:04:00.400 --> 01:04:03.710
important because in deep
learning, scale matters.

01:04:03.710 --> 01:04:06.099
And so the size of the
network that you can train

01:04:06.099 --> 01:04:08.509
is
extremely important.

01:04:08.510 --> 01:04:10.580
And so if it's efficient
on the current hardware,

01:04:10.579 --> 01:04:11.746
then you can make it bigger.

01:04:14.945 --> 01:04:17.900
You mentioned that if you do
it with multiple modalities

01:04:17.900 --> 01:04:19.740
of data, [INAUDIBLE].

01:04:21.722 --> 01:04:22.889
How does that actually work?

01:04:22.889 --> 01:04:26.359
Do you leave the different
data as different tokens,

01:04:26.360 --> 01:04:29.220
or is it [INAUDIBLE]?

01:04:29.219 --> 01:04:31.349
No, so yeah, so you
take your image,

01:04:31.349 --> 01:04:33.239
and you apparently chop
it up into patches.

01:04:33.239 --> 01:04:35.369
So those are the first
thousand tokens or whatever.

01:04:35.369 --> 01:04:37.139
And now, I have a special--

01:04:37.139 --> 01:04:40.934
so radar could be also,
but I don't actually

01:04:40.934 --> 01:04:43.920
want to make a
representation of radar.

01:04:43.920 --> 01:04:46.075
But you just need to
chop it up and enter it.

01:04:46.074 --> 01:04:47.699
And then you have to
encode it somehow.

01:04:47.699 --> 01:04:48.659
Like the transformer
needs to know

01:04:48.659 --> 01:04:49.951
that they're coming from radar.

01:04:49.952 --> 01:04:52.290
So you create a special--

01:04:52.289 --> 01:04:55.706
you have some kind of a
special token of that to--

01:04:55.706 --> 01:04:57.289
these radar tokens
are slightly

01:04:57.289 --> 01:04:58.759
different in the
representation, and it's

01:04:58.760 --> 01:05:00.050
learnable by gradient descent.

01:05:00.050 --> 01:05:03.500
And like vehicle
information would also

01:05:03.500 --> 01:05:07.920
come in with a special embedded
token that can be learned.
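
NOTE A minimal NumPy sketch of the multi-modal recipe described in this answer, under these assumptions. A 16x16x3 image is chopped into 4x4 patches. A hypothetical radar input has 10 returns with 5 features each. Each modality gets its own linear projection plus one learned "type" vector, so the transformer can tell where a token came from. The arrays are randomly initialized here, whereas in a real model they would be trained by gradient descent. All sizes and names are illustrative.
```python
import numpy as np
rng = np.random.default_rng(0)
d_model, patch = 32, 4
def patchify(image, p):
    # Chop an (H, W, C) image into non-overlapping p x p patches, each flattened.
    H, W, C = image.shape
    grid = image.reshape(H // p, p, W // p, p, C).transpose(0, 2, 1, 3, 4)
    return grid.reshape(-1, p * p * C)
image = rng.normal(size=(16, 16, 3))
radar = rng.normal(size=(10, 5))  # hypothetical: 10 radar returns, 5 features each
# Per-modality linear projections into the shared model width.
W_img = rng.normal(scale=0.02, size=(patch * patch * 3, d_model))
W_radar = rng.normal(scale=0.02, size=(5, d_model))
# One learnable "type" vector per modality marks where each token came from.
type_emb = {
    "image": rng.normal(scale=0.02, size=d_model),
    "radar": rng.normal(scale=0.02, size=d_model),
}
img_tokens = patchify(image, patch) @ W_img + type_emb["image"]
radar_tokens = radar @ W_radar + type_emb["radar"]
# Everything goes into one set; attention can then run over all tokens together.
tokens = np.concatenate([img_tokens, radar_tokens], axis=0)
print(tokens.shape)
```
The result is one set of 26 tokens (16 image patches plus 10 radar returns) that attention can mix freely.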

01:05:07.920 --> 01:05:09.289
So--

01:05:09.289 --> 01:05:11.654
So how do you align
those, like, really--

01:05:11.655 --> 01:05:12.830
Actually, you don't.

01:05:12.829 --> 01:05:13.938
It's all just a set.

01:05:13.938 --> 01:05:14.480
And there's--

01:05:14.480 --> 01:05:18.744
Even the [INAUDIBLE]

01:05:18.744 --> 01:05:20.869
Yeah, it's all just a set,
but you can positionally

01:05:20.869 --> 01:05:23.190
encode these sets if you want.

01:05:23.190 --> 01:05:26.150
So positional
encoding means you can

01:05:26.150 --> 01:05:28.130
hardwire, for example,
the coordinates

01:05:28.130 --> 01:05:29.510
like using [INAUDIBLE].

01:05:29.510 --> 01:05:31.310
You can hardwire
that, but it's better

01:05:31.309 --> 01:05:33.380
if you don't hardwire
the position.

01:05:33.380 --> 01:05:34.768
It's just a vector
that is always

01:05:34.768 --> 01:05:35.934
hanging out at this location.

01:05:35.934 --> 01:05:37.909
Whatever content is
there, it just adds onto it.

01:05:37.909 --> 01:05:39.289
And this vector is
trainable by backprop.

01:05:39.289 --> 01:05:40.164
That's how you do it.
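
NOTE A minimal sketch of a learned (not hard-wired) positional encoding as described here. There is one vector per position, and it is simply added to the token content at that position. Sizes and names are illustrative assumptions. The tables are randomly initialized rather than trained by backprop.
```python
import numpy as np
rng = np.random.default_rng(0)
seq_len, d_model, vocab = 8, 16, 100
# Learnable tables; in a real model both would be updated by backprop.
token_emb = rng.normal(scale=0.02, size=(vocab, d_model))
pos_emb = rng.normal(scale=0.02, size=(seq_len, d_model))  # one vector per slot
tokens = rng.integers(0, vocab, size=seq_len)
# Each position's vector is added onto whatever content sits at that position.
x = token_emb[tokens] + pos_emb[np.arange(seq_len)]
print(x.shape)
```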

01:05:43.458 --> 01:05:43.958
Good point.

01:05:43.958 --> 01:05:45.994
I don't really like
the [INAUDIBLE].

01:05:48.735 --> 01:05:51.400
They seem to work, but it
seems like they're sometimes

01:05:51.400 --> 01:06:08.867
[INAUDIBLE]

01:06:08.867 --> 01:06:10.659
I'm not sure if I
understand your question.

01:06:10.659 --> 01:06:11.295
[LAUGHTER]

01:06:11.295 --> 01:06:12.700
So I mean the
positional encodings

01:06:12.699 --> 01:06:14.619
like they're actually like not--

01:06:14.619 --> 01:06:16.969
OK, so they have very little
inductive bias or something

01:06:16.969 --> 01:06:17.469
like that.

01:06:17.469 --> 01:06:19.636
They're just vectors hanging
out at a location always,

01:06:19.637 --> 01:06:23.900
and you're trying to help
the network in some way.

01:06:23.900 --> 01:06:28.710
And I think the
intuition is good,

01:06:28.710 --> 01:06:30.490
but if you have
enough data, usually,

01:06:30.489 --> 01:06:33.699
trying to mess with
it is a bad thing.

01:06:33.699 --> 01:06:35.199
Trying to enter
knowledge when you

01:06:35.199 --> 01:06:36.574
have enough
knowledge in the data

01:06:36.574 --> 01:06:38.164
set itself is not
usually productive.

01:06:38.164 --> 01:06:40.164
So it all really depends
on what scale you want.

01:06:40.164 --> 01:06:41.949
If you have infinity
data, then you actually

01:06:41.949 --> 01:06:43.059
want to encode less and less.

01:06:43.059 --> 01:06:44.299
That turns out to work better.

01:06:44.300 --> 01:06:46.269
And if you have very little
data, then actually, you do

01:06:46.269 --> 01:06:47.230
want to encode some biases.

01:06:47.230 --> 01:06:49.179
And maybe if you have a
much smaller data set, then

01:06:49.179 --> 01:06:50.596
maybe convolutions
are a good idea

01:06:50.597 --> 01:06:55.269
because you actually have this
bias coming from your filters.

01:06:55.269 --> 01:06:58.969
But I think-- so the transformer
is extremely general,

01:06:58.969 --> 01:07:01.230
but there are ways to
mess with the encodings

01:07:01.230 --> 01:07:02.271
to put in more structure.

01:07:02.271 --> 01:07:05.039
Like you could, for example,
encode [INAUDIBLE] and fix it,

01:07:05.039 --> 01:07:07.164
or you could actually go
to the attention mechanism

01:07:07.164 --> 01:07:10.831
and say, OK, if my image
is chopped up into patches,

01:07:10.831 --> 01:07:13.039
this patch can only communicate
to this neighborhood.

01:07:13.039 --> 01:07:15.170
And you just do that in
the attention matrix,

01:07:15.170 --> 01:07:18.152
you just mask out whatever
you don't want to communicate.

01:07:18.152 --> 01:07:19.610
And so people really
play with this

01:07:19.610 --> 01:07:22.724
because the full
attention is inefficient.

01:07:22.724 --> 01:07:25.159
So they will intersperse,
for example, layers

01:07:25.159 --> 01:07:26.869
that only communicate
in little patches

01:07:26.869 --> 01:07:28.639
and then layers that
communicate globally.

01:07:28.639 --> 01:07:30.679
And they will do all
kinds of tricks like that.
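
NOTE A minimal NumPy sketch of restricting attention to a neighborhood by masking the attention matrix. It assumes a 4x4 grid of image patches, where each patch may only attend to patches within one grid step, including diagonals. Disallowed pairs get a score of minus infinity before the softmax, so their weights become exactly zero. A model could interleave layers using such a mask with unmasked (global) layers. The function names here are hypothetical.
```python
import numpy as np
def local_attention_mask(grid_h, grid_w, radius=1):
    # Patch i may attend to patch j only if they are within `radius` grid
    # cells of each other (Chebyshev distance); True means "allowed".
    rows, cols = np.divmod(np.arange(grid_h * grid_w), grid_w)
    dr = np.abs(rows[:, None] - rows[None, :])
    dc = np.abs(cols[:, None] - cols[None, :])
    return np.maximum(dr, dc) <= radius
def masked_softmax_attention(q, k, v, allowed):
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = np.where(allowed, scores, -np.inf)  # mask out disallowed pairs
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights
rng = np.random.default_rng(0)
mask = local_attention_mask(4, 4, radius=1)
q = k = v = rng.normal(size=(16, 8))
out, weights = masked_softmax_attention(q, k, v, mask)
print(out.shape, bool(np.all(weights[~mask] == 0)))
```
A corner patch ends up with 4 allowed partners and an interior patch with 9.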

01:07:30.679 --> 01:07:33.922
So you can slowly bring
in more inductive bias.

01:07:33.922 --> 01:07:35.630
You would do it, but
the inductive biases

01:07:35.630 --> 01:07:38.990
are factored out
from the core transformer.

01:07:38.989 --> 01:07:41.957
And they are factored out
in the interconnectivity

01:07:41.958 --> 01:07:42.500
of the nodes.

01:07:42.500 --> 01:07:44.909
And they are factored
out in the positional encodings--

01:07:44.909 --> 01:07:49.657
and you can mess with
this for computation.

01:07:49.657 --> 01:08:01.067
[INAUDIBLE]

01:08:02.530 --> 01:08:06.407
So there's probably about 200
papers on this now if not more.

01:08:06.407 --> 01:08:07.990
They're kind of hard
to keep track of.

01:08:07.989 --> 01:08:10.119
Honestly, like my Safari
browser, which is-- oh,

01:08:10.119 --> 01:08:13.750
it's all up on my computer,
like 200 open tabs.

01:08:13.750 --> 01:08:20.579
But yes, I'm not
even sure if I want

01:08:20.579 --> 01:08:23.609
to pick my favorite honestly.

01:08:23.609 --> 01:08:29.904
Yeah, [INAUDIBLE]

01:08:42.600 --> 01:08:45.146
Maybe you can use a transformer
like that [INAUDIBLE]

01:08:45.145 --> 01:08:46.978
The other one that I
actually like even more

01:08:46.979 --> 01:08:49.289
is potentially, keep
the context length fixed

01:08:49.289 --> 01:08:53.086
but allow the network to
somehow use a scratch pad.

01:08:53.087 --> 01:08:55.545
And so the way this works is
you will teach the transformer

01:08:55.545 --> 01:08:57.869
somehow via examples
in [INAUDIBLE] hey,

01:08:57.869 --> 01:09:00.265
you actually have a scratch pad.

01:09:00.265 --> 01:09:01.890
Basically, you can't
remember too much.

01:09:01.890 --> 01:09:02.939
Your context length is finite.

01:09:02.939 --> 01:09:04.200
But you can use a scratch pad.

01:09:04.199 --> 01:09:06.426
And you do that by emitting
a start scratch pad,

01:09:06.426 --> 01:09:08.759
and then writing whatever you
want to remember, and then

01:09:08.760 --> 01:09:10.079
end scratch pad.

01:09:10.079 --> 01:09:12.750
And then you continue
with whatever you want.

01:09:12.750 --> 01:09:14.345
And then later
when it's decoding,

01:09:14.345 --> 01:09:15.720
you actually have
special objects

01:09:15.720 --> 01:09:18.090
that when you detect
start scratch pad,

01:09:18.090 --> 01:09:19.739
you will like save
whatever it puts

01:09:19.739 --> 01:09:22.639
in there in like an external thing
and allow it to attend over it.
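
NOTE A minimal sketch of the scratch-pad idea. It assumes hypothetical delimiter strings <scratch> and </scratch> that the model has been taught, via examples, to emit. Decoding-side code moves anything between them into an external list. It later prepends those notes to the context so the model can attend over them. This illustrates the described mechanism and is not any known system's implementation.
```python
# Hypothetical special tokens; a real system would define its own vocabulary.
START, END = "<scratch>", "</scratch>"
def extract_scratchpad(generated, memory):
    """Move every START...END span out of the visible text into external memory."""
    visible = []
    pos = 0
    while True:
        s = generated.find(START, pos)
        if s == -1:
            visible.append(generated[pos:])
            break
        e = generated.find(END, s)
        if e == -1:  # unterminated note: keep it visible rather than guess
            visible.append(generated[pos:])
            break
        visible.append(generated[pos:s])
        memory.append(generated[s + len(START):e].strip())
        pos = e + len(END)
    return "".join(visible)
def build_context(memory, recent_text):
    # Saved notes are prepended so later decoding steps can attend over them.
    notes = "".join(f"{START}{m}{END}" for m in memory)
    return notes + recent_text
memory = []
text = extract_scratchpad("The total is <scratch>carry=1, partial=42</scratch>still computing", memory)
print(text)
print(build_context(memory, "next chunk"))
```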

01:09:22.640 --> 01:09:25.140
So basically, you can teach the
transformer just dynamically

01:09:25.140 --> 01:09:27.479
because it's so meta-learned.

01:09:27.479 --> 01:09:30.060
You can teach it dynamically
to use other gizmos and gadgets

01:09:30.060 --> 01:09:31.927
and allow it to expand
its memory that way

01:09:31.926 --> 01:09:32.759
if that makes sense.

01:09:32.760 --> 01:09:35.533
It's just like a human learning
to use a notepad, right?

01:09:35.533 --> 01:09:37.200
You don't have to
keep it in your brain.

01:09:37.199 --> 01:09:39.119
So keeping things in your
brain is like the context length

01:09:39.119 --> 01:09:39.994
of the transformer.

01:09:39.994 --> 01:09:42.119
But maybe we can just
give it a notebook.

01:09:42.119 --> 01:09:45.149
And then it can query the
notebook, and read from it,

01:09:45.149 --> 01:09:46.396
and write to it.

01:09:46.396 --> 01:09:48.689
[INAUDIBLE] transformer to
plug in another transformer.

01:09:48.689 --> 01:09:50.645
[LAUGHTER]

01:09:53.090 --> 01:09:58.140
[INAUDIBLE]

01:10:09.140 --> 01:10:10.520
I don't know if I detected that.

01:10:10.520 --> 01:10:12.853
I feel like-- did you feel
like there was more than just

01:10:12.853 --> 01:10:14.720
a long prompt that's unfolding?

01:10:14.720 --> 01:10:19.930
Yeah, [INAUDIBLE]

01:10:19.930 --> 01:10:22.960
I didn't try extensively, but
I did see a [INAUDIBLE] event.

01:10:22.960 --> 01:10:25.270
And I felt like the block
size was just moved.

01:10:28.162 --> 01:10:28.829
Maybe I'm wrong.

01:10:28.829 --> 01:10:31.199
I don't actually know about
the internals of ChatGPT.

01:10:31.199 --> 01:10:33.085
We have two online questions.

01:10:33.085 --> 01:10:35.984
So one question is, "what do
you think about architecture

01:10:35.984 --> 01:10:38.889
[INAUDIBLE]?"

01:10:38.890 --> 01:10:39.510
S4?

01:10:39.510 --> 01:10:40.930
S4.

01:10:40.930 --> 01:10:41.430
I'm sorry.

01:10:41.430 --> 01:10:42.670
I don't know S4.

01:10:42.670 --> 01:10:45.340
Which one is this one?

01:10:45.340 --> 01:10:47.710
The second question, this
one's a personal question.

01:10:47.710 --> 01:10:49.725
"What are you going
to work on next?"

01:10:49.725 --> 01:10:51.364
[INAUDIBLE]

01:10:51.364 --> 01:10:53.739
I mean, so right now, I'm
working on things like nanoGPT.

01:10:53.739 --> 01:10:54.939
Where is nanoGPT?

01:10:58.765 --> 01:11:01.140
I mean, I'm going basically
slightly away from computer vision

01:11:01.140 --> 01:11:03.869
and like computer
vision-based products, to do

01:11:03.869 --> 01:11:05.309
a little bit in the language domain.

01:11:05.310 --> 01:11:06.472
Where's nanoGPT?

01:11:06.471 --> 01:11:07.416
OK, nanoGPT.

01:11:07.416 --> 01:11:10.215
So originally, I had minGPT,
which I rewrote to nanoGPT.

01:11:10.215 --> 01:11:11.819
And I'm working on this.

01:11:11.819 --> 01:11:14.219
I'm trying to reproduce
GPTs, and I mean,

01:11:14.220 --> 01:11:16.050
I think something
like ChatGPT, I think,

01:11:16.050 --> 01:11:17.970
incrementally improved
in a product fashion

01:11:17.970 --> 01:11:19.980
would be extremely interesting.

01:11:19.979 --> 01:11:23.009
And I think a lot
of people feel it,

01:11:23.010 --> 01:11:24.960
and that's why it went so wide.

01:11:24.960 --> 01:11:28.020
So I think there's
something like a Google plus

01:11:28.020 --> 01:11:31.960
plus plus to build that I
think is more interesting.

01:11:31.960 --> 01:11:34.649
Shall we give our speaker
a round of applause?
