Hi, everyone. Welcome to CS25: Transformers United V2. This is a course that was held at Stanford in the winter of 2023. This course is not about robots that can transform into cars, as this picture might suggest. Rather, it's about the deep learning models that have taken the world by storm and have revolutionized the field of AI and others. Starting from natural language processing, transformers have been applied all over: computer vision, reinforcement learning, biology, robotics, et cetera.

We have an exciting set of videos lined up for you, with some truly fascinating speakers and talks presenting how they're applying transformers to research in different fields and areas. We hope you'll enjoy and learn from these videos. So without any further ado, let's get started.

This is a purely introductory lecture, and we'll go into the building blocks of transformers. So first, let's start with introducing the instructors.

As for me, I'm currently on a temporary deferral from the PhD program, and I'm leading AI at a robotics startup, Collaborative Robotics, that is working on some general-purpose robots, somewhat like [INAUDIBLE]. I'm very passionate about robotics and building learning algorithms. My research interests are in reinforcement learning, computer vision, and reward modeling, and I have a bunch of publications in robotics, autonomous driving, and other areas. My undergrad was at Cornell. If someone is from Cornell, so nice to [INAUDIBLE].

So I'm Stephen, currently a first-year CS PhD here. I previously did my master's at CMU and my undergrad at Waterloo. I'm mainly into NLP research, anything involving language and text, but more recently, I've been getting more into computer vision as well as [INAUDIBLE]. And just some stuff I do for fun: a lot of music stuff, mainly piano. Some self-promo: I post a lot on my Insta, YouTube, and TikTok, so if you guys want to check it out. My friends and I are also starting a Stanford piano club, so if anybody's interested, feel free to email or DM me for details.
Other than that: martial arts, bodybuilding, a huge fan of K-dramas and anime, and an occasional gamer. [LAUGHS]

OK, cool. Yeah, so my name is Rylan. Instead of talking about myself, I just want to very briefly say that I'm super excited to take this class-- sorry-- to teach this class. Excuse me. I took it the last time it was offered, and I had a bunch of fun. I thought we brought in a really great group of speakers last time, and I'm super excited for this offering. And yeah, I'm thankful that you're all here, and I'm looking forward to a really fun quarter together. Thank you.

Yeah, so fun fact: Rylan was the most outspoken student last year. So if someone wants to become an instructor next year, you know what to do. [LAUGHTER]

OK, cool. Let's see. OK, I think we have a few minutes. So what we hope you will learn in this class is, first of all, how transformers work; second, how they are being applied beyond NLP-- nowadays, you pretty much see them everywhere in AI and machine learning; and third, what some new and interesting directions of research in these topics are.

Cool. So this class is just introductory, so we're just talking about the basics of transformers, introducing them, and talking about the self-attention mechanism on which they're founded. And we'll do a deeper dive on models like BERT and GPT, stuff like that. So with that, happy to get started.

OK, so let me start with presenting the attention timeline. Attention all started with this one paper, "Attention Is All You Need" by Vaswani et al. in 2017. That was the beginning of transformers. Before that, we had the prehistoric era, where we had models like RNNs, LSTMs, and simple attention mechanisms that didn't work that well. Starting in 2017, we saw this explosion of transformers into NLP, where people started using them for everything.
I even heard this quote from Google; it's like, our performance increased every time we [INAUDIBLE]. [CHUCKLES]

Then, from 2018 to 2020, we saw this explosion of transformers into other fields, like vision, biology, and a bunch of other areas. And last year, 2021, was the start of the generative era, where we got a lot of generative modeling, with models like Codex, GPT, DALL-E, Stable Diffusion-- a lot of things happening in generative modeling. And we started scaling up in AI.

And now, the present. So this is 2022 and the start of '23, and now we have models like ChatGPT, Whisper, and a bunch of others. And we're scaling onwards without stopping, so that's great. So that's the future.

So going more into this: once, there were RNNs. We had Seq2Seq models, LSTMs, GRUs. What worked there was that they were good at encoding history, but what did not work was that they couldn't encode long sequences, and they were very bad at encoding context.

So consider this example: trying to predict the last word in the text, "I grew up in France, dot, dot, dot. I speak fluent..." Here, you need to understand the context for the model to predict "French," and the attention mechanism is very good at that, whereas if you're just using LSTMs, it doesn't work that well.

Another thing transformers are good at, based on content, is context prediction-- like finding attention maps. If I have a word like "it," what noun does it refer to? And we can put the appropriate attention on one of the possible candidates. And this works better than existing mechanisms.

OK, so where we were in 2021: we were on the verge of takeoff. We were starting to realize the potential of transformers in different fields. We solved a lot of long-sequence problems, like protein folding with AlphaFold, and offline RL. We started to see few-shot and zero-shot generalization.
We saw multimodal tasks and applications, like generating images from language-- so that was DALL-E. And it feels like [INAUDIBLE]. There was also a talk on transformers that you can watch on YouTube. Yeah, cool.

And this is where we were going from 2021 to 2022: we have gone from the verge of [INAUDIBLE], and now we are seeing unique applications in audio generation, art, music, storytelling. We are starting to see these new capabilities, like commonsense reasoning, logical reasoning, and mathematical reasoning. We are also now able to get human alignment and interaction: we're able to use reinforcement learning with human feedback, which is how ChatGPT is trained to perform really well. We have a lot of mechanisms for controlling toxicity, bias, and ethics now. And there are also a lot of developments in other areas, like diffusion models.

Cool. So the future is a spaceship, and we are all excited about it. There are a lot more applications that we can enable, and it'll be great if you can see transformers also up there. One big example is video understanding and generation. That is something that everyone is interested in, and I'm hoping we'll see a lot of models in this area this year-- also finance, business. I'll be very excited to see GPT author a novel, but for that we need to solve very long sequence modeling, and most transformer models are still limited to 4,000 tokens or something like that. So we need to make them generalize much better on long sequences.

We also want to have generalized agents that can do a lot of multitask, multi-input predictions, like Gato, and I think we will see more of that, too. And finally, we also want domain-specific models. So you might want a GPT model for, let's say, your health-- that could be a DoctorGPT model. You might have a LawyerGPT model that's trained on only law data.
So currently, we have GPT models that are trained on everything, but we might start to see more niche models that are good at one task. And we could have a mixture of experts. You can think of it like this: just as you'd normally consult an expert, you'll have expert AI models, and you can go to a different AI model for your different needs.

There are still a lot of missing ingredients to make this all successful. The first of all is external memory. We are already starting to see this with models like ChatGPT, where the interactions are short-lived: there's no long-term memory, and they don't have the ability to remember or store conversations long-term. And this is something we want to fix.

Second is reducing the computational complexity. The attention mechanism is quadratic in the sequence length, which is slow, and we want to reduce that and make it faster.

Another thing we want to do is enhance the controllability of these models. A lot of these models can be stochastic, and we want to be able to control what sort of outputs we get from them. You might have experienced this with ChatGPT: if you just refresh, you get a different output each time. But you might want a mechanism that controls what sort of things you get.

And finally, we want to align our state-of-the-art language models with how the human brain works. We are seeing a surge here, but we still need more research on how we can make them more aligned. Thank you.

Great, hi. Yes, I'm excited to be here. I live very nearby, so I got the invite to come to class, and I was like, OK, I'll just walk over. But then I spent like 10 hours on the slides, so it wasn't as simple. So yeah, I'm going to talk about transformers. I'm going to skip the first two items over there; I'm not going to talk about those. We'll talk about that one, just to simplify the lecture, since we don't have time.
OK, so I wanted to provide a little bit of context on why this transformers class even exists-- so, a little bit of historical context. I feel like Bilbo over there, telling you guys about ancient history. I don't know if you guys saw Lord of the Rings. Basically, I joined AI in roughly 2012, so maybe a decade ago. And back then, you wouldn't even say that you joined AI, by the way. That was like a dirty word. Now it's OK to talk about. But back then, it was not even deep learning-- it was machine learning. That was the term you would use if you were serious. But now, AI is OK to use, I think.

So basically, do you even realize how lucky you are, potentially entering this area in roughly 2023? Back then, in 2011 or so, when I was working specifically on computer vision, your pipelines looked like this. If you wanted to classify some images, you would go to a paper-- and I think this is representative-- and you would have three pages in the paper describing a whole zoo, a whole kitchen sink, of different kinds of features and descriptors. And you would go to a poster session at a computer vision conference, and everyone would have their favorite feature descriptor that they were proposing. And it's totally ridiculous, and you would take notes on which ones you should incorporate into your pipeline, because you would extract all of them, and then you would put an SVM on top. So that's what you would do.

So here's two pages: make sure you get your sparse SIFT histograms, your SSIMs, your color histograms, textons, tiny images. And don't forget the geometry-specific histograms. All of them basically have complicated code by themselves, so you're collecting code from everywhere and running it, and it was a total nightmare. And on top of that, it also didn't work. [LAUGHTER]

So this, I think, represents the predictions from that time.
You would just get predictions like this once in a while, and you'd just shrug your shoulders-- that just happens once in a while. Today, you would be looking for a bug.

And worse than that, every single chunk of AI had its own completely separate vocabulary that it worked with. So if you go to NLP papers, those papers would be completely different. You're reading an NLP paper, and you're like, what is this part-of-speech tagging, morphological analysis, syntactic parsing, coreference resolution? What is NP, VB, DT, JJ? And you're confused. So the vocabulary and everything was completely different, and you couldn't read papers, I would say, across different areas.

Now, that changed a little bit starting in 2012, when Alex Krizhevsky and colleagues basically demonstrated that if you scale a large neural network on a large dataset, you can get very strong performance. Up till then, there had been a lot of focus on algorithms, but this showed that neural nets actually scale very well. So you now need to worry about compute and data, and you can scale it up, and it works pretty well. And then that recipe actually did copy-paste across many areas of AI. So we started to see neural networks pop up everywhere since 2012: in computer vision, and NLP, and speech, and translation, in RL, and so on. Everyone started to use the same kind of modeling toolkit, the same modeling framework. And now when you go to NLP and you start reading papers there-- in machine translation, for example; this is the sequence-to-sequence paper, which we'll come back to in a bit-- you start to read those papers, and you're like, OK, I can recognize these words. There's a neural network, there are some parameters, there's an optimizer. It starts to read like things that you know of. So that tremendously decreased the barrier to entry across the different areas.
And then, I think, the big deal is that when the transformer came out in 2017, it's not even just that the toolkits and the neural networks were similar-- it's that literally the architectures converged to, like, one architecture that you copy-paste across everything, seemingly. So this was kind of an unassuming machine translation paper at the time, proposing the transformer architecture. But what we've found since then is that you can just basically copy-paste this architecture and use it everywhere, and what changes is the details of the data, and the chunking of the data, and how you feed it in. That's a caricature, but it's kind of a correct first-order statement. And so now, papers are even more similar-looking, because everyone's just using the transformer. And so this convergence was remarkable to watch; it unfolded over the last decade, and it's pretty crazy to me.

What I find interesting is that I think this is some kind of hint that we're maybe converging to something that the brain is doing, because the brain is very homogeneous and uniform across the entire sheet of your cortex. And OK, maybe some of the details are changing, but those feel like the hyperparameters of a transformer. Your auditory cortex and your visual cortex and everything else look very similar. So maybe we're converging to some kind of uniform, powerful learning algorithm here. Something like that, I think, is interesting and exciting.

OK, so I want to talk briefly about where the transformer came from, historically. I want to start in 2003. I like this paper quite a bit. It was the first popular application of neural networks to the problem of language modeling-- predicting, in this case, the next word in a sequence, which allows you to build generative models over text. And in this case, they were using a multi-layer perceptron, so a very simple neural net. The neural net took three words and predicted the probability distribution for the fourth word in the sequence.
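As a reference for the kind of model being described-- a minimal sketch only, with illustrative names and sizes rather than the paper's actual configuration-- a 2003-style neural language model might look like this in PyTorch:

```python
import torch
import torch.nn as nn

# Sketch of a 2003-style neural language model: embed the previous
# three words, concatenate the embeddings, pass them through a hidden
# layer, and predict a distribution over the vocabulary for word four.
class MLPLanguageModel(nn.Module):
    def __init__(self, vocab_size=10000, n_context=3, n_embd=30, n_hidden=100):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, n_embd)         # word feature vectors
        self.fc1 = nn.Linear(n_context * n_embd, n_hidden)  # hidden layer
        self.fc2 = nn.Linear(n_hidden, vocab_size)          # scores for the next word

    def forward(self, idx):            # idx: (batch, 3) integer word indices
        x = self.emb(idx)              # (batch, 3, n_embd)
        x = x.flatten(1)               # concatenate the three embeddings
        x = torch.tanh(self.fc1(x))
        return self.fc2(x)             # logits; a softmax gives P(fourth word)

model = MLPLanguageModel()
logits = model(torch.randint(0, 10000, (4, 3)))  # a batch of 4 three-word contexts
probs = logits.softmax(dim=-1)                   # (4, vocab_size)
```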
So this was all well and good at this point. Now, over time, people started to apply this to machine translation. That brings us to the sequence-to-sequence paper from 2014, which was pretty influential. And the big problem here was, OK, we don't just want to take three words and predict the fourth; we want to predict how to go from an English sentence to a French sentence. And the key problem was, you can have an arbitrary number of words in English and an arbitrary number of words in French, so how do you get an architecture that can process this variably sized input?

And so here they used an LSTM, and there are basically two chunks of it, which are covered up by this here. But basically, you have an encoder LSTM on the left, and it just consumes one word at a time and builds up a context of what it has read. And then that acts as a conditioning vector to the decoder RNN or LSTM, which basically goes chunk, chunk, chunk for the next word in the sequence, translating the English to French or something like that.

Now, the big problem with this, which people identified, I think, very quickly and tried to resolve, is what's called the encoder bottleneck. This entire English sentence that we are trying to condition on is packed into a single vector that goes from the encoder to the decoder. And that is just too much information to potentially maintain in a single vector, and it didn't seem correct. And so people were looking around for ways to alleviate the encoder bottleneck, as it was called at the time. And that brings us to this paper, Neural Machine Translation by Jointly Learning to Align and Translate.
And here, just quoting from the abstract: "In this paper, we conjecture that the use of a fixed-length vector is a bottleneck in improving the performance of this basic encoder-decoder architecture, and propose to extend this by allowing a model to automatically (soft-)search for parts of a source sentence that are relevant to predicting a target word, without having to form these parts as a hard segment explicitly." So this was a way to look back to the words that are coming from the encoder, and it was achieved using this soft search. So as you are decoding the words here, while you are decoding them, you are allowed to look back at the words at the encoder via this soft attention mechanism proposed in this paper.

And this paper, I think, is the first time that I saw, basically, attention. So your context vector that comes from the encoder is a weighted sum of the hidden states of the words in the encoding, and the weights of this sum come from a softmax that is based on these compatibilities between the current state as you're decoding and the hidden states generated by the encoder. And so this is the first time that you really start to see it, and these are the current, modern equations of attention. I think this was the first paper I saw it in, and it's the first time the word "attention" was used, as far as I know, to name this mechanism.

So I actually tried to dig into the details of the history of attention. The first author here, Dzmitry-- I had an email correspondence with him, and I basically sent him an email, like, Dzmitry, this is really interesting; transformers have taken over. Where did you come up with the soft attention mechanism that ends up being the heart of the transformer? And to my surprise, he wrote me back this massive email, which was really fascinating. So this is an excerpt from that email.
So basically, he talks about how he was looking for a way to avoid this bottleneck between the encoder and the decoder. He had some ideas about cursors that traverse the sequences, which didn't quite work out. And then here: "So one day, I had this thought that it would be nice to enable the decoder RNN to learn to search where to put the cursor in the source sequence. This was sort of inspired by translation exercises that learning English in my middle school involved. Your gaze shifts back and forth between the source and target sequence as you translate." So literally-- I thought this was kind of interesting-- he's not a native English speaker, and here, that gave him an edge in machine translation that led to attention, and then led to the transformer. So that's really fascinating.

"I expressed the soft search as a softmax and then weighted averaging of the BiRNN states. And basically, to my great excitement, this worked from the very first try." So really, I think, an interesting piece of history. And as it later turned out, the name RNNSearch was kind of lame, so the better name, attention, came from Yoshua on one of the final passes as they went over the paper. So maybe "Attention Is All You Need" would have been called "RNNSearch Is All You Need," but we have Yoshua Bengio to thank for a little bit of a better name, I would say. So apparently, that's the history of this, which I thought was interesting.

OK, so that brings us to 2017, which is Attention Is All You Need. So this attention component, which in Dzmitry's paper was just one small segment-- there's all this bidirectional RNN, RNN decoder-- and this Attention Is All You Need paper is saying, OK, you can actually delete everything. What's making this work very well is just the attention by itself. And so: delete everything, keep attention. And what's remarkable about this paper, actually, is that usually you see papers that are very incremental: they add one thing, and they show that it's better.
But I feel like Attention Is All You Need was a mix of multiple things at the same time. They were combined in a very unique way, and it also achieved a very good local minimum in the architecture space. And so to me, this is really a landmark paper that is quite remarkable and, I think, had quite a lot of work behind the scenes.

So: delete all the RNNs, just keep attention. Because attention operates over sets-- and I'm going to get to this in a second-- you now need to positionally encode your inputs, because attention doesn't have the notion of space by itself. I have to be very careful here. They adopted this residual network structure from ResNets. They interspersed attention with multi-layer perceptrons. They used layer norms, which came from a different paper. They introduced the concept of multiple heads of attention that are applied in parallel. And they gave us, I think, a fairly good set of hyperparameters that are used to this day: the expansion factor in the multi-layer perceptron goes up by 4x-- and we'll go into a bit more detail-- and this 4x has stuck around. And I believe there are a number of papers that try to play with all kinds of little details of the transformer, and nothing sticks, because this is actually quite good.

The only thing, to my knowledge, that didn't stick was this reshuffling of the layer norms to go into the pre-norm version: here you see the layer norms after the multi-headed attention and the feed-forward; they just put them before instead. So, just a reshuffling of layer norms-- but otherwise, the GPTs and everything else that you're seeing today is basically the 2017 architecture from five years ago. And even though everyone is working on it, it's proven remarkably resilient, which I think is really interesting. There are innovations that, I think, have been adopted, also in positional encoding; it's more common to use rotary and relative positional encodings and so on. (A sketch of one such pre-norm block follows below.)
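To make that block structure concrete, here is a minimal pre-norm sketch-- an illustration, not the paper's reference code-- assuming PyTorch's built-in nn.MultiheadAttention, and omitting the causal mask and dropout for brevity; the hyperparameters are illustrative:

```python
import torch.nn as nn

# One pre-norm transformer block: layer norms moved *before* the
# attention and the MLP (the one change that stuck), residual
# connections around both, and the 4x expansion in the MLP.
class Block(nn.Module):
    def __init__(self, n_embd=768, n_head=12):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = nn.MultiheadAttention(n_embd, n_head, batch_first=True)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),  # the 4x expansion factor
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
        )

    def forward(self, x):              # x: (batch, time, n_embd)
        h = self.ln1(x)
        a, _ = self.attn(h, h, h)      # self-attention: q, k, v from the same tokens
        x = x + a                      # residual around attention
        x = x + self.mlp(self.ln2(x))  # residual around the MLP
        return x
```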
So I think there have been changes, but for the most part, it's proven very resilient. So, really quite an interesting paper.

Now, I wanted to go into the attention mechanism. And I think the way I interpret it is not similar to the ways I've seen it presented before, so let me try a different way of presenting how I see it. Basically, to me, attention is kind of like the communication phase of the transformer. The transformer interleaves two phases: the communication phase, which is the multi-headed attention, and the computation phase, which is this multi-layer perceptron or [INAUDIBLE].

So the communication phase is really just data-dependent message passing on directed graphs. And you can think of it as, OK, forget machine translation, forget everything. We just have directed graphs, and at each node, you are storing a vector. And then let me talk about the communication phase-- how these vectors talk to each other in this directed graph. The compute phase later is just a multi-layer perceptron, which then basically acts on every node individually. But how do these nodes talk to each other in this directed graph?

So I wrote some simple Python, basically, to create one round of communication using attention as the message passing scheme. So here, a node has this private data vector-- you can think of it as information private to this node. And then it can also emit a key, a query, and a value, and simply, that's done by a linear transformation from this node. So the key is-- sorry. The query is: what are the things that I'm looking for? The key is: what are the things that I have? And the value is: what are the things that I will communicate?
And so then, when you have your graph that's made up of nodes and some random edges, when you actually have these nodes communicating, what's happening is: you loop over all the nodes individually, in some random order, and you're at some node, and you get the query vector q-- which is, I'm a node in some graph, and this is what I'm looking for. And that's just achieved via this linear transformation here. And then we look at all the inputs that point to this node, and they broadcast what the things they have are, which is their keys. So they broadcast the keys; I have the query. Then those interact by dot product to get scores. So basically, simply by doing a dot product, you get some unnormalized weighting of the interestingness of all of the information in the nodes that point to me, relative to the things I'm looking for. And then when you normalize that with a softmax, so it just sums to 1, you basically end up using those scores-- which now sum to 1 and form a probability distribution-- and you do a weighted sum of the values to get your update.

So: I have a query; they have keys; dot products give the interestingness, or, like, affinity; softmax normalizes it; and then a weighted sum of those values flows to me and updates me. And this is happening for each node individually, and then we update at the end. And this kind of message passing scheme is at the heart of the transformer. It happens in a more vectorized, batched way that is more confusing, and it's also interspersed with layer norms and things like that to make the training behave better. But that's roughly what's happening in the attention mechanism, I think, on a high level.

So yeah, in the communication phase of the transformer, this message passing scheme happens in every head in parallel, and then in every layer in series, with different weights each time. And that's it, as far as the multi-headed attention goes.
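Since the original Python isn't reproduced in the transcript, here is a reconstruction of the kind of sketch being described-- one round of attention as message passing on a directed graph. The names (Node, attention_round) and the toy graph are illustrative, not the exact code from the talk:

```python
import numpy as np

rng = np.random.default_rng(42)

class Node:
    def __init__(self, n_embd=20):
        self.data = rng.standard_normal(n_embd)        # information private to this node
        # linear transformations that emit key / query / value
        self.wkey = rng.standard_normal((n_embd, n_embd))
        self.wquery = rng.standard_normal((n_embd, n_embd))
        self.wvalue = rng.standard_normal((n_embd, n_embd))

    def key(self):   return self.wkey @ self.data      # what do I have?
    def query(self): return self.wquery @ self.data    # what am I looking for?
    def value(self): return self.wvalue @ self.data    # what will I communicate?

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def attention_round(nodes, edges):
    """edges[i] lists the indices of the nodes that point at node i."""
    updates = []
    for i, node in enumerate(nodes):
        q = node.query()                               # what node i is looking for
        inputs = [nodes[j] for j in edges[i]]
        scores = np.array([inp.key() @ q for inp in inputs])  # dot-product affinities
        weights = softmax(scores)                      # normalize so they sum to 1
        updates.append(sum(w * inp.value() for w, inp in zip(weights, inputs)))
    for node, u in zip(nodes, updates):                # update everyone at the end
        node.data = u

# A decoder-like triangular connectivity: each node sees itself and its predecessors.
nodes = [Node() for _ in range(4)]
edges = {0: [0], 1: [0, 1], 2: [0, 1, 2], 3: [0, 1, 2, 3]}
attention_round(nodes, edges)
```

Multiple heads would run this same round several times in parallel with different weight matrices, and layers would repeat it in series.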
And so if you look at these encoder-decoder models, you can think of them in terms of the connectivity of these nodes in the graph. You can think of it as, OK, all these tokens that are in the encoder that we want to condition on are fully connected to each other. So when they communicate, they communicate fully when you calculate their features. But in the decoder, because we are trying to have a language model, we don't want communication from future tokens, because they would give away the answer at this step. So the tokens in the decoder are fully connected from all the encoder states, and then they are also fully connected from everything that has been decoded so far. And so you end up with this triangular structure in the directed graph. But that's the message passing scheme that this basically implements.

And then you have to be a little bit careful, because in the cross-attention here with the decoder, you consume the features from the top of the encoder. So think of it as: in the encoder, all the nodes are looking at each other, all the tokens are looking at each other many, many times, and they really figure out what's in there, and then the decoder is looking only at the top nodes.

So that's roughly the message passing scheme. I was going to go into more of an implementation of a transformer. I don't know if there are any questions about this.

[INAUDIBLE] self-attention and multi-headed attention, but what is the advantage of [INAUDIBLE]?

Yeah, so self-attention and multi-headed attention: the multi-headed attention is just this attention scheme, but applied multiple times in parallel. Multiple heads just means independent applications of the same attention. So this message passing scheme basically just happens in parallel, multiple times, with different weights for the query, key, and value. So you can almost look at it like, in parallel, I'm seeking different kinds of information from different nodes.
And I'm collecting it all in the same node. It's all done in parallel. So heads are really just copy-paste in parallel, and layers are copy-paste in series. Maybe that makes sense.

And self-attention-- when it's self-attention, what that's referring to is that every node here produces its key, query, and value from this individual node. So as I described it here, this is really self-attention, because every one of these nodes produces a key, a query, and a value from this individual node. When you have cross-attention-- you have one cross-attention here, coming from the encoder-- that just means that the queries are still produced from this node, but the keys and the values are produced as a function of nodes that are coming from the encoder.

So I have my queries, because I'm trying to decode, say, the fifth word in the sequence, and I'm looking for certain things because I'm the fifth word. And then the keys and the values-- in terms of the source of information that could answer my queries-- can come from the previous nodes in the current decoding sequence, or from the top of the encoder. So all the nodes that have already seen all of the encoding tokens many, many times can now broadcast what they contain in terms of information.

So I guess, to summarize: cross-attention and self-attention only differ in where the keys and the values come from. Either the keys and values are produced from this same node, or they are produced from some external source, like an encoder and the nodes over there. But algorithmically, it's the same mathematical operations.

Question? Yeah, OK. So, two questions for you. First question is, in the message passing [INAUDIBLE]

So think of it as-- each one of these nodes is a token. I guess I don't have a very good picture of it in the transformer, but this node here could represent the third word in the output in the decoder, and in the beginning, it is just the embedding of the word.
And then, OK, I have to think through this analogy a little bit more-- I came up with it this morning. [LAUGHTER] [INAUDIBLE]

What's an example instantiation of the [INAUDIBLE] nodes-- as in blocks, or embeddings?

These nodes are basically the vectors. I'll go to the implementation, and then maybe I'll make the connections to the graph. So, with this intuition in mind, let me now go to nanoGPT, which is a concrete implementation of a transformer that is very minimal. I worked on this over the last few days, and here it is reproducing GPT-2 on OpenWebText. So it's a pretty serious implementation that reproduces GPT-2, I would say, provided enough compute-- this was one node of 8 GPUs for 38 hours or something like that, if I remember correctly. And it's very readable: it's 300 lines, so everyone can take a look at it. And yeah, let me basically briefly step through it.

So let's try to have a decoder-only transformer. What that means is that it's a language model: it tries to model the next word in the sequence, or the next character in the sequence. So the data we train on is always some kind of text. Here's some fake Shakespeare-- sorry, this is real Shakespeare; we're going to produce fake Shakespeare. This is called the Tiny Shakespeare dataset, which is one of my favorite toy datasets. You take all of Shakespeare and concatenate it, and it's a 1-megabyte file. And then you can train language models on it and get infinite Shakespeare, if you like, which I think is kind of cool.

So we have text. The first thing we need to do is convert it to a sequence of integers, because transformers natively process-- well, you can't plug text into a transformer; you need to somehow encode it.
775 00:32:11,912 --> 00:32:13,380 So the way that encoding is done is, 776 00:32:13,380 --> 00:32:15,390 for example, in the simplest case, 777 00:32:15,390 --> 00:32:18,810 every character gets an integer, and then instead of "hi 778 00:32:18,809 --> 00:32:21,799 there," we would have this sequence of integers. 779 00:32:21,799 --> 00:32:25,490 So then you can encode every single character as an integer 780 00:32:25,490 --> 00:32:27,529 and get a massive sequence of integers. 781 00:32:27,529 --> 00:32:29,089 You just concatenate it all into one 782 00:32:29,089 --> 00:32:31,419 large, long one-dimensional sequence. 783 00:32:31,420 --> 00:32:32,750 And then you can train on it. 784 00:32:32,750 --> 00:32:34,563 Now, here, we only have a single document. 785 00:32:34,563 --> 00:32:36,980 In some cases, if you have multiple independent documents, 786 00:32:36,980 --> 00:32:38,914 what people like to do is create special tokens, 787 00:32:38,914 --> 00:32:40,414 and they intersperse those documents 788 00:32:40,414 --> 00:32:42,500 with those special end-of-text tokens 789 00:32:42,500 --> 00:32:46,160 that they splice in between to create boundaries. 790 00:32:46,160 --> 00:32:50,860 But those boundaries actually don't have any modeling impact. 791 00:32:50,859 --> 00:32:52,609 It's just that the transformer is supposed 792 00:32:52,609 --> 00:32:55,849 to learn via backpropagation that an end-of-document 793 00:32:55,849 --> 00:33:00,019 token means that you should wipe the memory. 794 00:33:00,019 --> 00:33:02,000 OK, so then we produce batches. 795 00:33:02,000 --> 00:33:04,339 So these batches of data just mean 796 00:33:04,339 --> 00:33:06,379 that we go back to the one-dimensional sequence, 797 00:33:06,380 --> 00:33:08,780 and we take out chunks of this sequence. 798 00:33:08,779 --> 00:33:13,774 So say the block size is 8. The block size indicates 799 00:33:13,775 --> 00:33:17,750 the maximum length of context that your transformer will 800 00:33:17,750 --> 00:33:18,295 process. 801 00:33:18,295 --> 00:33:20,509 So if our block size is 8, that means 802 00:33:20,509 --> 00:33:23,720 that we are going to have up to eight characters of context 803 00:33:23,720 --> 00:33:26,630 to predict the ninth character in a sequence. 804 00:33:26,630 --> 00:33:29,120 And the batch size indicates how many sequences in parallel 805 00:33:29,119 --> 00:33:30,119 we're going to process. 806 00:33:30,119 --> 00:33:31,879 And we want this to be as large as possible, 807 00:33:31,880 --> 00:33:33,650 so we're fully taking advantage of the GPU 808 00:33:33,650 --> 00:33:36,540 and the parallelism [INAUDIBLE] So in this example, 809 00:33:36,539 --> 00:33:38,000 we're doing 4 by 8 batches. 810 00:33:38,000 --> 00:33:41,390 So every row here is an independent example, 811 00:33:41,390 --> 00:33:47,412 and each row is a small chunk of the sequence 812 00:33:47,412 --> 00:33:48,620 that we're going to train on. 813 00:33:48,619 --> 00:33:50,619 And then we have both the inputs and the targets 814 00:33:50,619 --> 00:33:52,579 at every single point here. 815 00:33:52,579 --> 00:33:55,159 So to fully spell out what's contained in a single 4 816 00:33:55,160 --> 00:33:57,320 by 8 batch to the transformer-- 817 00:33:57,319 --> 00:33:59,109 I sort of compacted it here-- 818 00:33:59,109 --> 00:34:04,669 when the input is 47, by itself, the target is 58.
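A minimal sketch of the character-level encoding and batching just described; this is illustrative rather than the actual nanoGPT code, and the dataset path is a made-up assumption:

    import torch

    text = open('tiny_shakespeare.txt').read()     # assumes the 1 MB dataset file exists locally
    chars = sorted(set(text))
    stoi = {ch: i for i, ch in enumerate(chars)}   # character -> integer
    data = torch.tensor([stoi[c] for c in text], dtype=torch.long)  # one long 1-D sequence

    block_size = 8   # maximum context length the transformer will process
    batch_size = 4   # how many sequences we process in parallel

    def get_batch():
        # take random chunks out of the one-dimensional sequence
        ix = torch.randint(len(data) - block_size, (batch_size,))
        x = torch.stack([data[i:i + block_size] for i in ix])          # inputs
        y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])  # targets = inputs offset by one
        return x, y  # both are (4, 8)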
819 00:34:04,670 --> 00:34:07,279 And when the input is the sequence 47, 58, 820 00:34:07,279 --> 00:34:08,929 the target is 1. 821 00:34:08,929 --> 00:34:13,019 And when it's 47, 58, 1, the target is 51, and so on. 822 00:34:13,019 --> 00:34:15,679 So actually, a single batch of examples that's 4 by 8 823 00:34:15,679 --> 00:34:17,490 actually has a ton of individual examples 824 00:34:17,490 --> 00:34:18,949 that we are expecting the transformer 825 00:34:18,949 --> 00:34:21,863 to learn on in parallel. 826 00:34:21,862 --> 00:34:23,779 And so you'll see that the batches are learned 827 00:34:23,780 --> 00:34:28,459 on completely independently, but the time dimension here, along 828 00:34:28,458 --> 00:34:30,948 the horizontal, is also trained on in parallel. 829 00:34:30,949 --> 00:34:34,309 So your real batch size is more like B times T. 830 00:34:34,309 --> 00:34:37,340 And it's just that the context grows linearly 831 00:34:37,340 --> 00:34:41,329 for the predictions that you make along the T direction 832 00:34:41,329 --> 00:34:42,509 in the model. 833 00:34:42,510 --> 00:34:45,664 So these are all the examples that the model will learn from 834 00:34:45,664 --> 00:34:48,830 in this single batch. 835 00:34:48,829 --> 00:34:52,768 So now, this is the GPT class. 836 00:34:52,768 --> 00:34:55,946 And because this is a decoder-only model, 837 00:34:55,947 --> 00:34:58,280 we're not going to have an encoder, because there's no 838 00:34:58,280 --> 00:34:59,952 English we're translating from-- 839 00:34:59,952 --> 00:35:02,119 we're not trying to condition on some other external 840 00:35:02,119 --> 00:35:02,779 information. 841 00:35:02,780 --> 00:35:05,510 We're just trying to produce a sequence of words that 842 00:35:05,510 --> 00:35:08,090 follow each other, or are likely to. 843 00:35:08,090 --> 00:35:10,658 So this is all PyTorch, and I'm going slightly faster 844 00:35:10,657 --> 00:35:12,949 because I'm assuming people have taken 231 or something 845 00:35:12,949 --> 00:35:15,210 along those lines. 846 00:35:15,210 --> 00:35:19,190 But here in the forward pass, we take these indices, 847 00:35:19,190 --> 00:35:24,500 and we encode the identity of the indices 848 00:35:24,500 --> 00:35:26,789 just via an embedding lookup table. 849 00:35:26,789 --> 00:35:31,190 So for every single integer, we index into a lookup table of 850 00:35:31,190 --> 00:35:34,460 vectors in this nn.Embedding and pull out 851 00:35:34,460 --> 00:35:38,099 the word vector for that token. 852 00:35:38,099 --> 00:35:41,431 And then, because the transformer by itself 853 00:35:41,431 --> 00:35:43,389 doesn't know about order-- it processes sets natively-- 854 00:35:43,389 --> 00:35:45,742 we need to also positionally encode these vectors 855 00:35:45,742 --> 00:35:47,659 so that we basically have both the information 856 00:35:47,659 --> 00:35:51,679 about the token identity and its place in the sequence, from 1 857 00:35:51,679 --> 00:35:53,869 to block size. 858 00:35:53,869 --> 00:35:56,659 Now, the information about what and where 859 00:35:56,659 --> 00:35:58,879 is combined additively: the token embeddings 860 00:35:58,880 --> 00:36:02,750 and the positional embeddings are just added, exactly as here. 861 00:36:02,750 --> 00:36:06,800 So then there's optional dropout. 862 00:36:06,800 --> 00:36:08,780 This x here basically just contains 863 00:36:08,780 --> 00:36:14,870 the set of words and their positions, 864 00:36:14,869 --> 00:36:16,786 and that feeds into the blocks of the transformer.
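A hedged sketch of this forward pass, loosely in the style of nanoGPT; the hyperparameter defaults are made up, and Block is sketched a bit further below:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class GPT(nn.Module):
        def __init__(self, vocab_size, block_size, n_embd=64, n_layer=4):
            super().__init__()
            self.wte = nn.Embedding(vocab_size, n_embd)   # token identity: "what"
            self.wpe = nn.Embedding(block_size, n_embd)   # position: "where"
            self.drop = nn.Dropout(0.1)                   # optional dropout
            self.blocks = nn.ModuleList([Block(n_embd) for _ in range(n_layer)])
            self.ln_f = nn.LayerNorm(n_embd)
            self.lm_head = nn.Linear(n_embd, vocab_size)  # language model head: just linear

        def forward(self, idx, targets=None):
            B, T = idx.shape
            pos = torch.arange(T, device=idx.device)
            x = self.drop(self.wte(idx) + self.wpe(pos))  # what and where, combined additively
            for block in self.blocks:
                x = block(x)
            logits = self.lm_head(self.ln_f(x))           # (B, T, vocab_size)
            loss = None
            if targets is not None:
                # negative log-likelihood over all B*T predictions at once
                loss = F.cross_entropy(logits.view(B * T, -1), targets.view(B * T))
            return logits, loss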
865 00:36:16,786 --> 00:36:18,744 And we're going to look into what's in a block here. 866 00:36:18,744 --> 00:36:20,599 But for now, this is just a series 867 00:36:20,599 --> 00:36:22,239 of blocks in the transformer. 868 00:36:22,239 --> 00:36:23,989 And then at the end, there's a layer norm, 869 00:36:23,989 --> 00:36:26,799 and then you're decoding the logits 870 00:36:26,800 --> 00:36:30,680 for the next word or next integer in the sequence, 871 00:36:30,679 --> 00:36:33,469 using a linear projection of the output of this transformer. 872 00:36:33,469 --> 00:36:36,859 So LM head here is short for language model head. 873 00:36:36,860 --> 00:36:38,945 It's just a linear function. 874 00:36:38,945 --> 00:36:42,710 So basically, positionally encode all the words, 875 00:36:42,710 --> 00:36:45,230 feed them into a sequence of blocks, 876 00:36:45,230 --> 00:36:47,690 and then apply a linear layer to get the probability 877 00:36:47,690 --> 00:36:50,336 distribution for the next character. 878 00:36:50,336 --> 00:36:51,920 And then if we have the targets, which 879 00:36:51,920 --> 00:36:54,057 we produced in the data loader-- 880 00:36:54,057 --> 00:36:55,849 and you'll notice that the targets are just 881 00:36:55,849 --> 00:36:59,297 the inputs offset by one in time-- 882 00:36:59,297 --> 00:37:01,380 then those targets feed into a cross-entropy loss. 883 00:37:01,380 --> 00:37:03,088 So this is just a negative log likelihood, 884 00:37:03,088 --> 00:37:04,705 a typical classification loss. 885 00:37:04,704 --> 00:37:08,840 So now let's drill into what's here in the blocks. 886 00:37:08,840 --> 00:37:11,470 So for these blocks that are applied sequentially, 887 00:37:11,469 --> 00:37:13,469 there's, again, as I mentioned, this communicate 888 00:37:13,469 --> 00:37:15,000 phase and the compute phase. 889 00:37:15,000 --> 00:37:17,135 So in the communicate phase, all the nodes 890 00:37:17,135 --> 00:37:21,260 get to talk to each other, and so, 891 00:37:21,260 --> 00:37:23,900 if our block size is 8, then we are 892 00:37:23,900 --> 00:37:26,405 going to have eight nodes in this graph. 893 00:37:26,405 --> 00:37:28,010 There are eight nodes in this graph. 894 00:37:28,010 --> 00:37:30,250 The first node is pointed to only by itself. 895 00:37:30,250 --> 00:37:33,324 The second node is pointed to by the first node and itself. 896 00:37:33,324 --> 00:37:35,449 The third node is pointed to by the first two nodes 897 00:37:35,449 --> 00:37:36,639 and itself, et cetera. 898 00:37:36,639 --> 00:37:38,940 So there are eight nodes here. 899 00:37:38,940 --> 00:37:42,472 So you apply-- there's a residual pathway in x. 900 00:37:42,472 --> 00:37:43,139 You take it out. 901 00:37:43,139 --> 00:37:45,449 You apply a layer norm and then the self-attention, 902 00:37:45,449 --> 00:37:47,879 so that these eight nodes communicate. 903 00:37:47,880 --> 00:37:50,220 But you have to keep in mind that the batch is 4. 904 00:37:50,219 --> 00:37:54,179 So because the batch is 4, this is also applied-- 905 00:37:54,179 --> 00:37:55,859 so we have eight nodes communicating, 906 00:37:55,860 --> 00:37:58,443 but there's a batch of four of them, each individually communicating 907 00:37:58,443 --> 00:37:59,880 among its own eight nodes. 908 00:37:59,880 --> 00:38:02,380 There's no crisscross across the batch dimension, of course. 909 00:38:02,380 --> 00:38:04,680 There's no batch norm anywhere, luckily.
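A sketch of one such block: the communicate phase (attention) and the compute phase (the per-node MLP, described next), each wrapped in a residual pathway with a layer norm. The 4x expansion inside the MLP is a common convention rather than a requirement, and CausalSelfAttention is sketched a bit further below:

    class Block(nn.Module):
        def __init__(self, n_embd):
            super().__init__()
            self.ln1 = nn.LayerNorm(n_embd)
            self.attn = CausalSelfAttention(n_embd)
            self.ln2 = nn.LayerNorm(n_embd)
            self.mlp = nn.Sequential(            # two-layer net with a GELU nonlinearity
                nn.Linear(n_embd, 4 * n_embd),
                nn.GELU(),
                nn.Linear(4 * n_embd, n_embd),
            )

        def forward(self, x):
            x = x + self.attn(self.ln1(x))   # communicate: nodes exchange information
            x = x + self.mlp(self.ln2(x))    # compute: each node transformed individually
            return x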
910 00:38:04,679 --> 00:38:06,809 And then once they've exchanged information, 911 00:38:06,809 --> 00:38:09,630 they are processed using the multi-layer perceptron. 912 00:38:09,630 --> 00:38:12,630 And that's the compute phase. 913 00:38:12,630 --> 00:38:18,137 And then, also, here we are missing the cross-attention 914 00:38:18,137 --> 00:38:19,679 because this is a decoder-only model. 915 00:38:19,679 --> 00:38:21,277 So all we have is this step here, 916 00:38:21,277 --> 00:38:22,860 the multi-headed attention, and that's 917 00:38:22,860 --> 00:38:24,579 this line, the communicate phase. 918 00:38:24,579 --> 00:38:27,119 And then we have the feed-forward, which is the MLP, 919 00:38:27,119 --> 00:38:29,710 and that's the compute phase. 920 00:38:29,710 --> 00:38:31,610 I'll take questions a bit later. 921 00:38:31,610 --> 00:38:34,745 Then the MLP here is fairly straightforward. 922 00:38:34,744 --> 00:38:38,069 The MLP is just individual processing on each node, 923 00:38:38,070 --> 00:38:41,530 just transforming the feature representation at that node. 924 00:38:41,530 --> 00:38:45,120 So it's applying a two-layer neural net 925 00:38:45,119 --> 00:38:47,204 with a GELU nonlinearity-- just 926 00:38:47,204 --> 00:38:49,079 think of it as a ReLU or something like that. 927 00:38:49,079 --> 00:38:51,400 It's just a nonlinearity. 928 00:38:51,400 --> 00:38:53,610 So the MLP is straightforward. 929 00:38:53,610 --> 00:38:55,760 I don't think there's anything too crazy there. 930 00:38:55,760 --> 00:38:57,760 And then this is the causal self-attention part, 931 00:38:57,760 --> 00:38:59,750 the communication phase. 932 00:38:59,750 --> 00:39:01,539 So this is like the meat of things 933 00:39:01,539 --> 00:39:03,670 and the most complicated part. 934 00:39:03,670 --> 00:39:06,599 It's only complicated because of the batching 935 00:39:06,599 --> 00:39:10,349 and the implementation detail of how you mask the connectivity 936 00:39:10,349 --> 00:39:13,619 in the graph so that you can't obtain 937 00:39:13,619 --> 00:39:15,119 any information from the future when 938 00:39:15,119 --> 00:39:16,327 you're predicting your token. 939 00:39:16,327 --> 00:39:18,429 Otherwise, it would give away the information. 940 00:39:18,429 --> 00:39:23,099 So if I'm the fifth token, at the fifth position, 941 00:39:23,099 --> 00:39:26,279 then I'm getting the fourth token coming into the input, 942 00:39:26,280 --> 00:39:29,010 and I'm attending to the third, second, and first, 943 00:39:29,010 --> 00:39:32,160 and I'm trying to figure out what the next token is. 944 00:39:32,159 --> 00:39:34,589 Well then, in this batch, in the next element 945 00:39:34,590 --> 00:39:37,050 over in the time dimension, the answer is at the input. 946 00:39:37,050 --> 00:39:40,360 So I can't get any information from there. 947 00:39:40,360 --> 00:39:41,860 So that's why this is all tricky. 948 00:39:41,860 --> 00:39:45,070 But basically, in the forward pass, 949 00:39:45,070 --> 00:39:50,658 we are calculating the queries, keys, and values based on x. 950 00:39:50,657 --> 00:39:52,449 So these are the keys, queries, and values. 951 00:39:52,449 --> 00:39:54,444 Here, when I'm computing the attention, 952 00:39:54,445 --> 00:39:58,019 I have the queries matrix-multiplying the keys. 953 00:39:58,019 --> 00:40:00,730 So this is the dot product, in parallel, for all the queries 954 00:40:00,730 --> 00:40:03,400 and all the keys in all the heads.
955 00:40:03,400 --> 00:40:06,160 So I failed to mention that there's also 956 00:40:06,159 --> 00:40:08,679 the aspect of the heads, which is also all done in parallel 957 00:40:08,679 --> 00:40:09,029 here. 958 00:40:09,030 --> 00:40:10,900 So we have the batch dimension, the time dimension, 959 00:40:10,900 --> 00:40:12,369 and the head dimension, and you end up 960 00:40:12,369 --> 00:40:14,779 with four-dimensional tensors, and it's all really confusing. 961 00:40:14,780 --> 00:40:17,110 So I invite you to step through it later and convince yourself 962 00:40:17,110 --> 00:40:19,059 that this is actually doing the right thing. 963 00:40:19,059 --> 00:40:21,549 But basically, you have the batch dimension, the head 964 00:40:21,550 --> 00:40:23,560 dimension, and the time dimension, 965 00:40:23,559 --> 00:40:25,250 and then you have features at each of them. 966 00:40:25,250 --> 00:40:28,630 And so this is evaluating, for all the batch elements, for all 967 00:40:28,630 --> 00:40:31,300 the head elements, and all the time elements, 968 00:40:31,300 --> 00:40:34,030 the simple Python that I gave you earlier, which is query 969 00:40:34,030 --> 00:40:35,769 dot-product key. 970 00:40:35,769 --> 00:40:38,949 Then here, we do a masked_fill, and what this is doing 971 00:40:38,949 --> 00:40:44,259 is basically clamping the attention between the nodes 972 00:40:44,260 --> 00:40:46,480 that are not supposed to communicate to negative 973 00:40:46,480 --> 00:40:47,110 infinity. 974 00:40:47,110 --> 00:40:48,485 And we're doing negative infinity 975 00:40:48,485 --> 00:40:51,220 because we're about to softmax, and so negative infinity will 976 00:40:51,219 --> 00:40:54,384 basically make the attention at those elements zero. 977 00:40:54,385 --> 00:40:56,590 And so here we basically end up 978 00:40:56,590 --> 00:41:03,370 with the weights, the affinities between these nodes, then optional 979 00:41:03,369 --> 00:41:03,880 dropout. 980 00:41:03,880 --> 00:41:08,460 And then here, the attention matrix-multiplied with v is basically 981 00:41:08,460 --> 00:41:10,960 the gathering of the information according to the affinities 982 00:41:10,960 --> 00:41:11,829 we calculated. 983 00:41:11,829 --> 00:41:14,529 And this is just a weighted sum of the values 984 00:41:14,530 --> 00:41:15,769 at all those nodes. 985 00:41:15,769 --> 00:41:19,030 So this matrix multiply is doing that weighted sum. 986 00:41:19,030 --> 00:41:20,993 And then transpose, contiguous, view-- 987 00:41:20,992 --> 00:41:22,659 because it's all complicated and batched 988 00:41:22,659 --> 00:41:24,789 in four-dimensional tensors, but it's really not 989 00:41:24,789 --> 00:41:26,889 doing anything-- optional dropout, 990 00:41:26,889 --> 00:41:30,679 and then a linear projection back to the residual pathway. 991 00:41:30,679 --> 00:41:34,710 So this is implementing the communication phase here. 992 00:41:34,710 --> 00:41:37,869 Then you can train this transformer. 993 00:41:37,869 --> 00:41:41,170 And then you can generate infinite Shakespeare. 994 00:41:41,170 --> 00:41:43,090 And you will simply do this by-- 995 00:41:43,090 --> 00:41:47,170 because our block size is 8, we start with some token. 996 00:41:47,170 --> 00:41:50,500 In this case, you 997 00:41:50,500 --> 00:41:53,050 can use something like a newline as the start token.
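Before going on to generation, here is a hedged sketch of the causal self-attention just walked through, reusing the imports from the sketches above; n_embd must divide evenly by n_head:

    class CausalSelfAttention(nn.Module):
        def __init__(self, n_embd, n_head=4, block_size=8):
            super().__init__()
            assert n_embd % n_head == 0
            self.n_head = n_head
            self.qkv = nn.Linear(n_embd, 3 * n_embd)  # queries, keys, values, all from x
            self.proj = nn.Linear(n_embd, n_embd)     # projection back to the residual pathway
            # lower-triangular connectivity: node t may only attend to nodes <= t
            self.register_buffer('mask', torch.tril(torch.ones(block_size, block_size)))

        def forward(self, x):
            B, T, C = x.shape
            q, k, v = self.qkv(x).split(C, dim=2)
            # reshape to (B, n_head, T, head_size): batch, heads, and time all in parallel
            q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
            k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
            v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
            att = (q @ k.transpose(-2, -1)) / (k.size(-1) ** 0.5)  # query . key, scaled
            # clamp affinities between nodes that must not communicate to -inf,
            # so the upcoming softmax drives them to zero
            att = att.masked_fill(self.mask[:T, :T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            y = att @ v                                        # weighted sum of the values
            y = y.transpose(1, 2).contiguous().view(B, T, C)   # re-assemble the heads
            return self.proj(y)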
998 00:41:53,050 --> 00:41:55,517 And then you communicate only to yourself, 999 00:41:55,516 --> 00:41:57,099 because there's a single node, and you 1000 00:41:57,099 --> 00:41:59,559 get the probability distribution for the first word 1001 00:41:59,559 --> 00:42:00,650 in the sequence. 1002 00:42:00,650 --> 00:42:03,603 And then you decode it for the first character 1003 00:42:03,603 --> 00:42:04,269 in the sequence-- 1004 00:42:04,269 --> 00:42:05,559 you decode the character. 1005 00:42:05,559 --> 00:42:06,549 And then you bring back the character, 1006 00:42:06,550 --> 00:42:08,019 and you re-encode it as an integer. 1007 00:42:08,019 --> 00:42:10,605 And now, you have the second thing. 1008 00:42:10,605 --> 00:42:12,760 And so you get-- 1009 00:42:12,760 --> 00:42:14,470 OK, we're at the first position, and this 1010 00:42:14,469 --> 00:42:17,659 is whatever integer it is; you add the positional encodings; 1011 00:42:17,659 --> 00:42:19,659 it goes into the sequence, goes into the transformer; 1012 00:42:19,659 --> 00:42:21,940 and again, this token now communicates 1013 00:42:21,940 --> 00:42:26,690 with the first token and its identity. 1014 00:42:26,690 --> 00:42:28,389 And so you just keep plugging it back in. 1015 00:42:28,389 --> 00:42:31,000 And once you run out of the block size, which is eight, 1016 00:42:31,000 --> 00:42:33,130 you start to crop, because you can never 1017 00:42:33,130 --> 00:42:34,660 have a block size of more than eight in the way you've 1018 00:42:34,659 --> 00:42:35,701 trained this transformer. 1019 00:42:35,702 --> 00:42:37,690 So we have more and more context until eight. 1020 00:42:37,690 --> 00:42:39,190 And then if you want to generate beyond eight, 1021 00:42:39,190 --> 00:42:41,481 you have to start cropping, because the transformer only 1022 00:42:41,481 --> 00:42:43,690 works for eight elements in the time dimension. 1023 00:42:43,690 --> 00:42:47,170 And so all of these transformers in the [INAUDIBLE] setting 1024 00:42:47,170 --> 00:42:50,590 have a finite block size, or context length, 1025 00:42:50,590 --> 00:42:54,460 and in typical models, this will be 1,024 tokens or 2,048 1026 00:42:54,460 --> 00:42:56,349 tokens, something like that. 1027 00:42:56,349 --> 00:42:58,559 But these tokens are usually BPE tokens, 1028 00:42:58,559 --> 00:43:00,434 or SentencePiece tokens, or WordPiece tokens. 1029 00:43:00,434 --> 00:43:02,539 There are many different encodings. 1030 00:43:02,539 --> 00:43:03,860 So it's not like that long. 1031 00:43:03,860 --> 00:43:05,349 And so that's why, I think, [INAUDIBLE].. 1032 00:43:05,349 --> 00:43:06,789 We really want to expand the context size, 1033 00:43:06,789 --> 00:43:08,469 and it gets gnarly because the attention 1034 00:43:08,469 --> 00:43:11,659 is quadratic in the [INAUDIBLE] case. 1035 00:43:11,659 --> 00:43:16,759 Now, if you want to implement an encoder instead of decoder 1036 00:43:16,760 --> 00:43:18,680 attention, 1037 00:43:18,679 --> 00:43:21,214 then all you have to do is this [INAUDIBLE] 1038 00:43:21,215 --> 00:43:23,340 and you just delete that line. 1039 00:43:23,340 --> 00:43:25,414 So if you don't mask the attention, 1040 00:43:25,414 --> 00:43:27,289 then all the nodes communicate with each other, 1041 00:43:27,289 --> 00:43:29,389 and everything is allowed, and information 1042 00:43:29,389 --> 00:43:31,129 flows between all the nodes. 1043 00:43:31,130 --> 00:43:35,750 So if you want to have the encoder here, just delete.
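A minimal sketch of that sampling loop, plus the one-line difference that turns decoder attention into encoder attention; generate here is a hypothetical helper, not nanoGPT's exact function:

    @torch.no_grad()
    def generate(model, idx, max_new_tokens, block_size=8):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]   # crop: the model only ever saw block_size of context
            logits, _ = model(idx_cond)
            probs = F.softmax(logits[:, -1, :], dim=-1)   # distribution over the next character
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, idx_next], dim=1)       # plug the sample back in
        return idx

    # And encoder-style (fully connected) attention is the same attention code
    # with one line deleted:
    #     att = att.masked_fill(self.mask[:T, :T] == 0, float('-inf'))   # <- delete this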
1044 00:43:35,750 --> 00:43:38,030 All the encoder blocks will use attention 1045 00:43:38,030 --> 00:43:39,380 where this line is deleted. 1046 00:43:39,380 --> 00:43:40,730 That's it. 1047 00:43:40,730 --> 00:43:44,480 So you're allowing whatever-- this encoder might store, say, 1048 00:43:44,480 --> 00:43:46,880 10 tokens, 10 nodes, and they are all 1049 00:43:46,880 --> 00:43:51,240 allowed to communicate with each other going up the transformer. 1050 00:43:51,239 --> 00:43:53,369 And then if you want to implement cross-attention, 1051 00:43:53,369 --> 00:43:55,327 so you have a full encoder-decoder transformer, 1052 00:43:55,327 --> 00:43:59,329 not just a decoder-only transformer like GPT, 1053 00:43:59,329 --> 00:44:03,159 then we need to also add cross-attention in the middle. 1054 00:44:03,159 --> 00:44:05,809 So here, there are three pieces: 1055 00:44:05,809 --> 00:44:06,469 a self-attention piece, 1056 00:44:06,469 --> 00:44:08,802 a cross-attention piece, 1057 00:44:08,802 --> 00:44:09,980 and this MLP. 1058 00:44:09,980 --> 00:44:12,320 And in the cross-attention, we need 1059 00:44:12,320 --> 00:44:14,570 to take the features from the top of the encoder. 1060 00:44:14,570 --> 00:44:16,789 We need to add one more line here, 1061 00:44:16,789 --> 00:44:20,090 and this would be the cross-attention instead of a-- 1062 00:44:20,090 --> 00:44:22,340 I should have implemented it instead of just pointing, 1063 00:44:22,340 --> 00:44:23,300 I think. 1064 00:44:23,300 --> 00:44:25,310 But there would be a cross-attention line here. 1065 00:44:25,309 --> 00:44:26,929 So we'd have three lines, because we 1066 00:44:26,929 --> 00:44:28,190 need to add another block. 1067 00:44:28,190 --> 00:44:31,400 And the queries would come from x, but the keys 1068 00:44:31,400 --> 00:44:35,043 and the values would come from the top of the encoder. 1069 00:44:35,043 --> 00:44:36,710 And there would basically be information 1070 00:44:36,710 --> 00:44:38,126 flowing from the encoder directly 1071 00:44:38,126 --> 00:44:41,420 to all the nodes inside x. 1072 00:44:41,420 --> 00:44:42,750 And then that's it. 1073 00:44:42,750 --> 00:44:44,255 So it's a very simple modification 1074 00:44:44,255 --> 00:44:47,369 of the decoder attention. 1075 00:44:47,369 --> 00:44:49,469 So you'll hear people talk about how you can have 1076 00:44:49,469 --> 00:44:51,884 a decoder-only model like GPT. 1077 00:44:51,885 --> 00:44:53,760 You can have an encoder-only model like BERT, 1078 00:44:53,760 --> 00:44:55,427 or you can have an encoder-decoder model 1079 00:44:55,427 --> 00:44:59,660 like, say, T5, doing things like machine translation. 1080 00:44:59,659 --> 00:45:04,143 And in BERT, you don't train it using this language modeling 1081 00:45:04,143 --> 00:45:06,059 setup that's autoregressive, where you're just 1082 00:45:06,059 --> 00:45:07,340 trying to predict the next [INAUDIBLE] in the sequence. 1083 00:45:07,340 --> 00:45:09,720 You train it using slightly different objectives. 1084 00:45:09,719 --> 00:45:12,000 You're putting in the full sentence, 1085 00:45:12,000 --> 00:45:14,454 and the full sentence is allowed to communicate fully. 1086 00:45:14,454 --> 00:45:16,829 And then you're trying to classify sentiment or something 1087 00:45:16,829 --> 00:45:18,039 like that. 1088 00:45:18,039 --> 00:45:21,489 So you're not trying to model the next token in the sequence.
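Returning to cross-attention for a moment, a hedged sketch of that extra block: queries from the decoder's x, keys and values from the top of the encoder. The names are illustrative, and it assumes the encoder and decoder share the same embedding size:

    class CrossAttention(nn.Module):
        def __init__(self, n_embd, n_head=4):
            super().__init__()
            assert n_embd % n_head == 0
            self.n_head = n_head
            self.q = nn.Linear(n_embd, n_embd)        # queries come from the decoder's x
            self.kv = nn.Linear(n_embd, 2 * n_embd)   # keys and values come from the encoder
            self.proj = nn.Linear(n_embd, n_embd)

        def forward(self, x, enc_out):
            B, T, C = x.shape
            S = enc_out.size(1)                       # number of encoder nodes
            q = self.q(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
            k, v = self.kv(enc_out).split(C, dim=2)
            k = k.view(B, S, self.n_head, C // self.n_head).transpose(1, 2)
            v = v.view(B, S, self.n_head, C // self.n_head).transpose(1, 2)
            att = F.softmax((q @ k.transpose(-2, -1)) / (k.size(-1) ** 0.5), dim=-1)
            y = (att @ v).transpose(1, 2).contiguous().view(B, T, C)
            return self.proj(y)   # no causal mask: every decoder node may see all encoder nodes

    # A decoder block would then have three residual lines instead of two, roughly:
    #     x = x + self.attn(self.ln1(x))             # causal self-attention
    #     x = x + self.cross(self.ln2(x), enc_out)   # cross-attention to the encoder
    #     x = x + self.mlp(self.ln3(x))              # MLP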
1089 00:45:21,489 --> 00:45:26,649 So these are trained slightly differently, 1090 00:45:26,650 --> 00:45:31,789 using masking and other denoising techniques. 1091 00:45:31,789 --> 00:45:32,289 OK. 1092 00:45:32,289 --> 00:45:34,570 So that's the transformer. 1093 00:45:34,570 --> 00:45:36,410 I'm going to continue. 1094 00:45:36,409 --> 00:45:38,565 So yeah, maybe more questions. 1095 00:45:38,565 --> 00:45:49,349 [INAUDIBLE] 1096 00:46:01,710 --> 00:46:06,030 This is like we are enforcing these constraints on it 1097 00:46:06,030 --> 00:46:12,610 by just masking [INAUDIBLE] 1098 00:46:12,610 --> 00:46:14,039 So I'm not sure if I fully follow. 1099 00:46:14,039 --> 00:46:16,769 So there are different ways to look at this analogy, 1100 00:46:16,769 --> 00:46:18,329 but one analogy is you can interpret 1101 00:46:18,329 --> 00:46:20,199 this graph as really fixed. 1102 00:46:20,199 --> 00:46:22,230 It's just that every time we do the communicate phase, 1103 00:46:22,230 --> 00:46:23,400 we are using different weights. 1104 00:46:23,400 --> 00:46:24,610 You can look at it that way. 1105 00:46:24,610 --> 00:46:26,680 So if we have a block size of eight, in my example, 1106 00:46:26,679 --> 00:46:27,762 we would have eight nodes. 1107 00:46:27,762 --> 00:46:29,309 Here we have 2, 4, 6. 1108 00:46:29,309 --> 00:46:30,989 OK, so we'd have eight nodes. 1109 00:46:30,989 --> 00:46:33,042 They would be connected in-- 1110 00:46:33,043 --> 00:46:35,460 you lay them out, and you only connect from left to right. 1111 00:46:35,460 --> 00:46:37,860 [INAUDIBLE] 1112 00:46:42,635 --> 00:46:44,010 Why would they connect-- usually, 1113 00:46:44,010 --> 00:46:46,410 the connections don't change as a function of the data 1114 00:46:46,409 --> 00:46:47,460 or something like that-- 1115 00:46:47,460 --> 00:46:51,990 [INAUDIBLE] 1116 00:47:00,293 --> 00:47:02,210 I don't think I've seen a single example where 1117 00:47:02,210 --> 00:47:03,139 the connectivity changes dynamically 1118 00:47:03,139 --> 00:47:04,021 as a function of the data. 1119 00:47:04,021 --> 00:47:05,480 Usually, the connectivity is fixed. 1120 00:47:05,480 --> 00:47:07,610 If you have an encoder, and you're training a BERT, 1121 00:47:07,610 --> 00:47:09,500 you have however many tokens you want, 1122 00:47:09,500 --> 00:47:11,639 and they are fully connected. 1123 00:47:11,639 --> 00:47:13,539 And if you have a decoder-only model, 1124 00:47:13,539 --> 00:47:15,289 you have this triangular thing, and if you 1125 00:47:15,289 --> 00:47:16,748 have an encoder-decoder, then you have, 1126 00:47:16,748 --> 00:47:21,269 awkwardly, two pools of nodes. 1127 00:47:21,269 --> 00:47:21,769 Yeah. 1128 00:47:24,639 --> 00:47:25,230 Go ahead. 1129 00:47:25,230 --> 00:47:45,010 [INAUDIBLE] I wonder, you know much more about this 1130 00:47:45,010 --> 00:47:46,604 than I know. 1131 00:47:46,603 --> 00:48:00,629 But do you have a sense of, like, if you ran [INAUDIBLE] 1132 00:48:00,630 --> 00:48:08,555 In my head, I'm thinking [INAUDIBLE] but then you also 1133 00:48:08,554 --> 00:48:13,099 have different things for one or more of [INAUDIBLE]---- 1134 00:48:13,099 --> 00:48:15,000 Yeah, it's really hard to say. That's 1135 00:48:15,000 --> 00:48:17,219 why I think this paper is so interesting, because, yeah, 1136 00:48:17,219 --> 00:48:18,569 usually, you'd see the path, 1137 00:48:18,570 --> 00:48:19,680 and maybe they had a path internally. 1138 00:48:19,679 --> 00:48:20,981 They just didn't publish it.
1139 00:48:20,981 --> 00:48:23,565 All you can see is the prior work, which didn't look like a transformer. 1140 00:48:23,565 --> 00:48:26,250 I mean, you have ResNets, which have a lot of this. 1141 00:48:26,250 --> 00:48:29,820 But a ResNet would be like this, but there's 1142 00:48:29,820 --> 00:48:31,200 no self-attention component. 1143 00:48:31,199 --> 00:48:35,579 But the MLP is kind of there in a ResNet. 1144 00:48:35,579 --> 00:48:37,710 So a ResNet looks very much like this, 1145 00:48:37,710 --> 00:48:40,349 except there's no-- you can use layer norms in ResNets, 1146 00:48:40,349 --> 00:48:41,219 I believe, as well. 1147 00:48:41,219 --> 00:48:43,509 Typically, sometimes, they would be batch norms. 1148 00:48:43,510 --> 00:48:45,210 So it is kind of like a ResNet. 1149 00:48:45,210 --> 00:48:47,190 It's like they took a ResNet, and they 1150 00:48:47,190 --> 00:48:50,369 put in a self-attention block in addition 1151 00:48:50,369 --> 00:48:52,139 to the preexisting MLP block, which 1152 00:48:52,139 --> 00:48:53,741 is kind of like convolutions. 1153 00:48:53,742 --> 00:48:55,575 And the MLP, strictly speaking, is like a 1154 00:48:55,574 --> 00:48:59,099 one-by-one convolution, but I think 1155 00:48:59,099 --> 00:49:04,110 the idea is similar in that the MLP is just a typical weights, 1156 00:49:04,110 --> 00:49:06,210 nonlinearity, weights operation. 1157 00:49:11,047 --> 00:49:13,089 But I will say, yeah, this is kind of interesting, 1158 00:49:13,090 --> 00:49:15,968 because a lot of that work is not there, 1159 00:49:15,967 --> 00:49:17,634 and then they give you this transformer. 1160 00:49:17,635 --> 00:49:18,820 And then it turns out, five years later, 1161 00:49:18,820 --> 00:49:20,860 it's not changed, even though everyone's trying to change it. 1162 00:49:20,860 --> 00:49:23,095 So it's interesting to me that it arrived 1163 00:49:23,094 --> 00:49:25,487 like a package, which I think is really 1164 00:49:25,487 --> 00:49:26,529 interesting historically. 1165 00:49:26,530 --> 00:49:30,100 And I also talked to the paper authors, 1166 00:49:30,099 --> 00:49:32,116 and they were unaware of the impact 1167 00:49:32,117 --> 00:49:33,950 that the transformer would have at the time. 1168 00:49:33,949 --> 00:49:37,419 So when you read this paper, actually, it's unfortunate, 1169 00:49:37,420 --> 00:49:39,548 because this is the paper that changed everything. 1170 00:49:39,547 --> 00:49:41,589 But when people read it, it's like question marks, 1171 00:49:41,590 --> 00:49:45,100 because it reads like a pretty random machine translation 1172 00:49:45,099 --> 00:49:46,139 paper. 1173 00:49:46,139 --> 00:49:47,304 It's like, oh, we're doing machine translation. 1174 00:49:47,304 --> 00:49:48,596 Oh, here's a cool architecture. 1175 00:49:48,597 --> 00:49:51,265 OK, great, good results. 1176 00:49:51,264 --> 00:49:53,589 It doesn't know what's going to happen. 1177 00:49:53,590 --> 00:49:56,260 [LAUGHS] And so when people read it today, 1178 00:49:56,260 --> 00:50:00,550 I think they're potentially confused. 1179 00:50:00,550 --> 00:50:02,152 I will have some tweets at the end, 1180 00:50:02,152 --> 00:50:03,610 but I think I would have renamed it, 1181 00:50:03,610 --> 00:50:08,755 with the benefit of hindsight-- well, I'll get to it. 1182 00:50:08,755 --> 00:50:15,112 [INAUDIBLE] 1183 00:50:20,920 --> 00:50:22,990 Yeah, I think that's a good question as well.
1184 00:50:22,989 --> 00:50:24,719 Currently, I mean, I certainly don't 1185 00:50:24,719 --> 00:50:27,329 love the autoregressive modeling approach. 1186 00:50:27,329 --> 00:50:29,250 I think it's kind of weird to sample a token 1187 00:50:29,250 --> 00:50:31,195 and then commit to it. 1188 00:50:31,195 --> 00:50:36,809 So maybe there are some hybrids 1189 00:50:36,809 --> 00:50:38,309 with diffusion, as an example, which 1190 00:50:38,309 --> 00:50:41,409 I think would be really cool, or we'll 1191 00:50:41,409 --> 00:50:44,319 find some other ways to edit the sequences later but still 1192 00:50:44,320 --> 00:50:47,177 in an autoregressive framework. 1193 00:50:47,177 --> 00:50:49,510 But I think diffusion is an up-and-coming modeling 1194 00:50:49,510 --> 00:50:51,677 approach that I personally find much more appealing. 1195 00:50:51,677 --> 00:50:54,190 When I sample text, I don't go chunk, chunk, chunk, 1196 00:50:54,190 --> 00:50:55,365 and commit. 1197 00:50:55,364 --> 00:50:58,299 I do a draft one, and then I do a better draft two. 1198 00:50:58,300 --> 00:51:00,880 And that feels like a diffusion process. 1199 00:51:00,880 --> 00:51:02,480 So that would be my hope. 1200 00:51:05,449 --> 00:51:07,759 OK, also a question. 1201 00:51:07,760 --> 00:51:20,338 So yeah, you'd think the [INAUDIBLE] 1202 00:51:20,338 --> 00:51:21,880 And then once we have the edge weights, 1203 00:51:21,880 --> 00:51:23,894 we just have to multiply by the values, 1204 00:51:23,894 --> 00:51:25,269 and then you just [INAUDIBLE] it. 1205 00:51:25,269 --> 00:51:27,159 Yes, yeah, that's right. 1206 00:51:27,159 --> 00:51:30,339 And you think there's an analogy within graph neural networks, 1207 00:51:30,340 --> 00:51:32,590 and they'll potentially-- 1208 00:51:32,590 --> 00:51:34,990 I find graph neural networks kind of a confusing term, 1209 00:51:34,989 --> 00:51:38,209 because, I mean, yeah, previously, 1210 00:51:38,210 --> 00:51:40,262 there was this notion of-- 1211 00:51:40,262 --> 00:51:42,429 I feel like maybe today everything is a graph neural 1212 00:51:42,429 --> 00:51:44,799 network, because a transformer is a graph neural network 1213 00:51:44,800 --> 00:51:45,760 processor. 1214 00:51:45,760 --> 00:51:48,260 The native representation that the transformer operates over 1215 00:51:48,260 --> 00:51:51,680 is sets that are connected by edges, in a directed way. 1216 00:51:51,679 --> 00:51:55,636 And so that's the native representation, and then, yeah. 1217 00:51:55,637 --> 00:51:57,720 OK, I should go on, because I still have 30 slides. 1218 00:51:57,719 --> 00:51:59,539 [INAUDIBLE] 1219 00:52:08,099 --> 00:52:11,339 Oh yeah, yeah, the root of d_k. I think it's basically 1220 00:52:11,340 --> 00:52:14,130 like, if you're initializing with random weights 1221 00:52:14,130 --> 00:52:17,140 sampled from a [INAUDIBLE], then as your dimension size grows, 1222 00:52:17,139 --> 00:52:19,349 so do your values--the variance grows. 1223 00:52:19,349 --> 00:52:23,400 And then your softmax will just become a one-hot vector. 1224 00:52:23,400 --> 00:52:25,410 So it's just a way to control the variance 1225 00:52:25,409 --> 00:52:28,049 and bring it always into a good range for the softmax-- 1226 00:52:28,050 --> 00:52:31,670 a nice, diffuse distribution. 1227 00:52:31,670 --> 00:52:37,869 OK, so it's almost like an initialization thing.
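A quick numerical illustration of that point: for random vectors, the dot product of queries and keys has variance on the order of d_k, so without the 1/sqrt(d_k) scaling the softmax saturates toward a one-hot vector.

    import torch
    import torch.nn.functional as F

    d_k = 512
    q, k = torch.randn(8, d_k), torch.randn(8, d_k)
    scores = q @ k.T
    print(scores.std())                               # roughly sqrt(512) ~ 22.6: far too peaky
    print(F.softmax(scores, dim=-1)[0])               # nearly a one-hot vector
    print(F.softmax(scores / d_k ** 0.5, dim=-1)[0])  # scaled: a nicely diffuse distribution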
1228 00:52:37,869 --> 00:52:41,469 OK, so transformers have been applied 1229 00:52:41,469 --> 00:52:44,319 to all the other fields, and the way this was done 1230 00:52:44,320 --> 00:52:46,900 is, in my opinion, in ridiculous ways, 1231 00:52:46,900 --> 00:52:49,389 honestly, because I was a computer vision person, 1232 00:52:49,389 --> 00:52:51,400 and you have ConvNets, and they make sense. 1233 00:52:51,400 --> 00:52:53,840 So what we're doing now with ViTs, as an example, is 1234 00:52:53,840 --> 00:52:56,215 you take an image, and you chop it up into little squares. 1235 00:52:56,215 --> 00:52:57,802 And then those squares literally 1236 00:52:57,802 --> 00:52:59,260 feed into a transformer, and that's 1237 00:52:59,260 --> 00:53:01,900 it, which is kind of ridiculous. 1238 00:53:01,900 --> 00:53:06,389 And so, I mean, yeah, the transformer 1239 00:53:06,389 --> 00:53:08,670 doesn't even, in the simplest case, really know where 1240 00:53:08,670 --> 00:53:10,470 these patches might come from. 1241 00:53:10,469 --> 00:53:12,929 They are usually positionally encoded, 1242 00:53:12,929 --> 00:53:16,379 but it has to rediscover a lot of the structure 1243 00:53:16,380 --> 00:53:19,180 of them, I think, in some ways. 1244 00:53:19,179 --> 00:53:23,089 And it's kind of weird to approach it that way. 1245 00:53:23,090 --> 00:53:25,579 But it's just the simplest baseline, 1246 00:53:25,579 --> 00:53:27,672 and just chopping up big images into small squares 1247 00:53:27,672 --> 00:53:29,839 and feeding them in as the individual nodes actually 1248 00:53:29,840 --> 00:53:30,620 works fairly well. 1249 00:53:30,619 --> 00:53:32,690 And then this is a transformer encoder, 1250 00:53:32,690 --> 00:53:34,760 so all the patches are talking to each other 1251 00:53:34,760 --> 00:53:36,960 throughout the entire transformer. 1252 00:53:36,960 --> 00:53:39,494 And the number of nodes here would be like nine. 1253 00:53:42,284 --> 00:53:44,909 Also, in speech recognition, you just take your mel spectrogram, 1254 00:53:44,909 --> 00:53:46,937 and you chop it up into slices, and you feed them 1255 00:53:46,938 --> 00:53:47,730 into a transformer. 1256 00:53:47,730 --> 00:53:49,920 So there was a paper like this, but also Whisper. 1257 00:53:49,920 --> 00:53:51,720 Whisper is a copy-paste transformer. 1258 00:53:51,719 --> 00:53:55,199 If you saw Whisper from OpenAI, you just chop up the mel spectrogram, 1259 00:53:55,199 --> 00:53:57,547 feed it into a transformer, and then pretend 1260 00:53:57,547 --> 00:53:58,589 you're dealing with text. 1261 00:53:58,590 --> 00:54:00,870 And it works very well. 1262 00:54:00,869 --> 00:54:03,692 Decision Transformer in RL: you take the states, actions, 1263 00:54:03,693 --> 00:54:05,610 and rewards that you experience in an environment, 1264 00:54:05,610 --> 00:54:07,693 and you just pretend it's a language. 1265 00:54:07,693 --> 00:54:09,610 You start to model the sequences of that, 1266 00:54:09,610 --> 00:54:11,640 and then you can use that for planning later. 1267 00:54:11,639 --> 00:54:13,319 That works really well. 1268 00:54:13,320 --> 00:54:15,382 Even things like AlphaFold--so we were briefly 1269 00:54:15,382 --> 00:54:17,590 talking about molecules and how you can plug them in. 1270 00:54:17,590 --> 00:54:19,507 At the heart of AlphaFold, computationally, 1271 00:54:19,507 --> 00:54:21,907 is also a transformer.
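A hedged sketch of the ViT-style "chop the image into squares" step just described, using a toy 48x48 image and 16x16 patches, so nine nodes:

    import torch

    img = torch.randn(1, 3, 48, 48)                # (batch, channels, height, width)
    P = 16                                         # patch size -> a 3x3 grid of patches
    patches = img.unfold(2, P, P).unfold(3, P, P)  # (1, 3, 3, 3, 16, 16)
    tokens = patches.permute(0, 2, 3, 1, 4, 5).reshape(1, 9, 3 * P * P)  # nine flattened patches
    # each flattened patch would then be linearly projected to n_embd, given a
    # positional embedding, and fed into a non-causal transformer encoder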
1272 00:54:21,907 --> 00:54:23,949 One thing I wanted to also say about transformers 1273 00:54:23,949 --> 00:54:26,289 is I find that they're very flexible, 1274 00:54:26,289 --> 00:54:28,150 and I really enjoy that. 1275 00:54:28,150 --> 00:54:31,228 I'll give you an example from Tesla. 1276 00:54:31,228 --> 00:54:32,769 You have a ConvNet that takes an image 1277 00:54:32,769 --> 00:54:34,300 and makes predictions about the image. 1278 00:54:34,300 --> 00:54:35,967 And then the big question is, how do you 1279 00:54:35,967 --> 00:54:37,269 feed in extra information? 1280 00:54:37,269 --> 00:54:38,920 And it's not always trivial. Like, say I 1281 00:54:38,920 --> 00:54:40,389 have additional information 1282 00:54:40,389 --> 00:54:43,480 that I want the outputs to be informed by. 1283 00:54:43,480 --> 00:54:45,112 Maybe I have other sensors, like radar. 1284 00:54:45,112 --> 00:54:47,320 Maybe I have some map information, or a vehicle type, 1285 00:54:47,320 --> 00:54:48,085 or some audio. 1286 00:54:48,085 --> 00:54:50,710 And the question is, how do you feed that information into a ConvNet? 1287 00:54:50,710 --> 00:54:52,329 Like, where do you feed it in? 1288 00:54:52,329 --> 00:54:54,429 Do you concatenate it? 1289 00:54:54,429 --> 00:54:55,210 Do you add it? 1290 00:54:55,210 --> 00:54:56,349 At what stage? 1291 00:54:56,349 --> 00:54:58,202 And so with a transformer, it's much easier, 1292 00:54:58,202 --> 00:55:00,369 because you just take whatever you want, you chop it 1293 00:55:00,369 --> 00:55:02,500 up into pieces, and you feed it in with the set 1294 00:55:02,500 --> 00:55:03,500 of what you had before. 1295 00:55:03,500 --> 00:55:04,690 And you let the self-attention figure out 1296 00:55:04,690 --> 00:55:06,106 how everything should communicate. 1297 00:55:06,106 --> 00:55:07,719 And that actually, apparently, works. 1298 00:55:07,719 --> 00:55:10,119 So just chop up everything and throw it into the mix 1299 00:55:10,119 --> 00:55:11,739 is like the way. 1300 00:55:11,739 --> 00:55:15,759 And it frees neural nets from this burden 1301 00:55:15,760 --> 00:55:19,332 of Euclidean space, where previously you 1302 00:55:19,331 --> 00:55:21,789 had to arrange your computation to conform to the Euclidean 1303 00:55:21,789 --> 00:55:25,304 space of three dimensions, of how you're laying out the compute. 1304 00:55:25,304 --> 00:55:26,679 Like, the compute actually kind of 1305 00:55:26,679 --> 00:55:29,859 happens in almost like 3D space, if you think about it. 1306 00:55:29,860 --> 00:55:32,050 But in attention, everything is just sets. 1307 00:55:32,050 --> 00:55:33,730 So it's a very flexible framework, 1308 00:55:33,730 --> 00:55:35,530 and you can just throw stuff into your conditioning set. 1309 00:55:35,530 --> 00:55:37,155 And everything just gets self-attended over. 1310 00:55:37,155 --> 00:55:39,595 So it's quite beautiful from that perspective. 1311 00:55:39,594 --> 00:55:43,219 OK, so now, what exactly makes transformers so effective? 1312 00:55:43,219 --> 00:55:44,719 I think a good example of this comes 1313 00:55:44,719 --> 00:55:48,230 from the GPT-3 paper, which I encourage people to read, 1314 00:55:48,230 --> 00:55:50,280 Language Models are Few-Shot Learners. 1315 00:55:50,280 --> 00:55:52,280 I would have probably renamed this a little bit. 1316 00:55:52,280 --> 00:55:54,380 I would have said something like transformers 1317 00:55:54,380 --> 00:55:57,769 are capable of in-context learning, or meta-learning.
1318 00:55:57,769 --> 00:56:00,097 That's what makes them really special. 1319 00:56:00,097 --> 00:56:02,180 So basically, the setting that they're working with 1320 00:56:02,179 --> 00:56:03,679 is, OK, I have some context, and I'm 1321 00:56:03,679 --> 00:56:04,887 trying-- like, say, a passage. 1322 00:56:04,887 --> 00:56:06,335 This is just one example of many. 1323 00:56:06,335 --> 00:56:08,840 I have a passage, and I'm asking questions about it. 1324 00:56:08,840 --> 00:56:12,762 And then, as part of the context in the prompt, 1325 00:56:12,762 --> 00:56:14,470 I'm giving the questions and the answers. 1326 00:56:14,469 --> 00:56:16,009 So I'm giving one example of a question-answer, 1327 00:56:16,010 --> 00:56:17,468 another example of a question-answer, 1328 00:56:17,467 --> 00:56:19,889 another example of a question-answer, and so on. 1329 00:56:19,889 --> 00:56:21,799 And this becomes-- 1330 00:56:21,800 --> 00:56:24,289 Oh yeah, people are going to have to leave soon, huh? 1331 00:56:24,289 --> 00:56:25,634 OK, is this really important? 1332 00:56:25,635 --> 00:56:26,177 Let me think. 1333 00:56:29,454 --> 00:56:31,329 OK, so what's really interesting is basically 1334 00:56:31,329 --> 00:56:35,380 that, with more examples given in the context, 1335 00:56:35,380 --> 00:56:37,200 the accuracy improves. 1336 00:56:37,199 --> 00:56:39,199 And so what that says is that the transformer 1337 00:56:39,199 --> 00:56:42,159 is able to somehow learn in the activations, 1338 00:56:42,159 --> 00:56:43,629 without doing any gradient descent 1339 00:56:43,630 --> 00:56:45,260 in a typical fine-tuning fashion. 1340 00:56:45,260 --> 00:56:48,460 So if you fine-tune, you have to give an example and the answer, 1341 00:56:48,460 --> 00:56:51,246 and you fine-tune it using gradient descent. 1342 00:56:51,246 --> 00:56:53,079 But it looks like the transformer internally, 1343 00:56:53,079 --> 00:56:54,519 in its activations, is doing something 1344 00:56:54,519 --> 00:56:56,050 that potentially looks like gradient descent, some kind 1345 00:56:56,050 --> 00:56:57,430 of meta-learning, 1346 00:56:57,429 --> 00:56:59,049 as it is reading the prompt. 1347 00:56:59,050 --> 00:57:01,678 And so in this paper, they go into, OK, 1348 00:57:01,677 --> 00:57:03,969 distinguishing this outer loop of stochastic gradient 1349 00:57:03,969 --> 00:57:06,302 descent and this inner loop of in-context learning. 1350 00:57:06,302 --> 00:57:08,679 So the inner loop is as the transformer is reading 1351 00:57:08,679 --> 00:57:12,339 the sequence, almost, and the outer loop is the training 1352 00:57:12,340 --> 00:57:14,032 by gradient descent. 1353 00:57:14,032 --> 00:57:15,490 So basically, there's some training 1354 00:57:15,489 --> 00:57:17,019 happening in the activations of the transformer 1355 00:57:17,019 --> 00:57:18,730 as it is consuming a sequence, and 1356 00:57:18,730 --> 00:57:21,099 maybe it very much looks like gradient descent. 1357 00:57:21,099 --> 00:57:23,307 And so there are some recent papers that hint at this 1358 00:57:23,307 --> 00:57:23,929 and study it. 1359 00:57:23,929 --> 00:57:25,387 And so as an example, in this paper 1360 00:57:25,387 --> 00:57:28,719 here, they propose something called the RAW operator.
1361 00:57:28,719 --> 00:57:32,072 And they argue that the RAW operator is implemented 1362 00:57:32,072 --> 00:57:33,489 by the transformer, and then they show 1363 00:57:33,489 --> 00:57:35,289 that you can implement things like ridge regression 1364 00:57:35,289 --> 00:57:36,599 on top of the RAW operator. 1365 00:57:36,599 --> 00:57:39,011 And so this is giving-- 1366 00:57:39,012 --> 00:57:40,720 There are papers hinting that maybe there 1367 00:57:40,719 --> 00:57:42,927 is something that looks like gradient-based learning 1368 00:57:42,927 --> 00:57:45,250 inside the activations of the transformer. 1369 00:57:45,250 --> 00:57:47,590 And I think this is not impossible to think through, 1370 00:57:47,590 --> 00:57:49,720 because what is gradient-based learning? 1371 00:57:49,719 --> 00:57:52,179 Forward pass, backward pass, and then an update. 1372 00:57:52,179 --> 00:57:54,250 Oh, that looks like a ResNet, right, 1373 00:57:54,250 --> 00:57:57,099 because you're adding to the weights. 1374 00:57:57,099 --> 00:57:59,511 So you start with an initial random set of weights; 1375 00:57:59,512 --> 00:58:01,720 forward pass, backward pass, and update your weights; 1376 00:58:01,719 --> 00:58:04,096 and then forward pass, backward pass, update the weights. 1377 00:58:04,097 --> 00:58:04,930 Looks like a ResNet. 1378 00:58:04,929 --> 00:58:10,179 A transformer is a ResNet. This is much more hand-wavy, 1379 00:58:10,179 --> 00:58:11,889 but basically, some papers are trying 1380 00:58:11,889 --> 00:58:14,525 to hint at why that would be potentially possible. 1381 00:58:14,525 --> 00:58:16,900 And then I have a bunch of tweets I just copy-pasted here 1382 00:58:16,900 --> 00:58:18,639 at the end. 1383 00:58:18,639 --> 00:58:20,519 These were meant for general consumption, 1384 00:58:20,519 --> 00:58:22,900 so they're a bit more high-level and a little bit hypey. 1385 00:58:22,900 --> 00:58:26,079 But I'm talking about why this architecture is so interesting 1386 00:58:26,079 --> 00:58:27,994 and why it potentially became so popular. 1387 00:58:27,994 --> 00:58:29,619 And I think it simultaneously optimizes 1388 00:58:29,619 --> 00:58:31,464 three properties that, I think, are very desirable. 1389 00:58:31,465 --> 00:58:33,130 Number one, the transformer is very 1390 00:58:33,130 --> 00:58:35,865 expressive in the forward pass. 1391 00:58:35,865 --> 00:58:37,509 It's able to implement 1392 00:58:37,510 --> 00:58:39,552 very interesting functions--potentially functions 1393 00:58:39,552 --> 00:58:41,920 that can even do meta-learning. 1394 00:58:41,920 --> 00:58:43,659 Number two, it is very optimizable, thanks 1395 00:58:43,659 --> 00:58:45,429 to things like residual connections, layer norms, 1396 00:58:45,429 --> 00:58:45,940 and so on. 1397 00:58:45,940 --> 00:58:47,731 And number three, it's extremely efficient. 1398 00:58:47,731 --> 00:58:49,929 This is not always appreciated, but the transformer, 1399 00:58:49,929 --> 00:58:51,554 if you look at the computational graph, 1400 00:58:51,554 --> 00:58:53,649 is a shallow, wide network, which 1401 00:58:53,650 --> 00:58:56,224 is perfect for taking advantage of the parallelism of GPUs. 1402 00:58:56,224 --> 00:58:58,599 So I think the transformer was designed very deliberately 1403 00:58:58,599 --> 00:59:00,730 to run efficiently on GPUs.
1404 00:59:00,730 --> 00:59:02,650 There's previous work, like Neural GPU, 1405 00:59:02,650 --> 00:59:05,680 that I really enjoy as well, which is really just, 1406 00:59:05,679 --> 00:59:08,559 how do we design neural nets that are efficient on GPUs, 1407 00:59:08,559 --> 00:59:10,420 thinking backwards from the constraints of the hardware-- 1408 00:59:10,420 --> 00:59:11,740 which I think is a very interesting way 1409 00:59:11,739 --> 00:59:12,489 to think about it. 1410 00:59:17,929 --> 00:59:21,789 Oh yeah, so here, I'm saying I probably would have called-- 1411 00:59:21,789 --> 00:59:24,489 I probably would have called the transformer a general-purpose, 1412 00:59:24,489 --> 00:59:27,819 efficient, optimizable computer, instead of Attention 1413 00:59:27,820 --> 00:59:28,570 Is All You Need. 1414 00:59:28,570 --> 00:59:31,930 That's what I would have maybe, in hindsight, called that paper. 1415 00:59:31,929 --> 00:59:37,349 It's proposing a model that is very general purpose, so 1416 00:59:37,349 --> 00:59:38,539 the forward pass is expressive. 1417 00:59:38,539 --> 00:59:40,759 It's very efficient in terms of GPU usage, 1418 00:59:40,760 --> 00:59:44,720 and it's easily optimizable by gradient descent and trains 1419 00:59:44,719 --> 00:59:46,511 very nicely. 1420 00:59:46,512 --> 00:59:48,730 And then I have some other hype tweets here. 1421 00:59:51,489 --> 00:59:53,339 Anyway, so you can read them later. 1422 00:59:53,340 --> 00:59:55,090 But I think this one is maybe interesting. 1423 00:59:55,090 --> 00:59:58,360 So if previous neural nets are special-purpose computers 1424 00:59:58,360 --> 01:00:00,490 designed for a specific task, GPT 1425 01:00:00,489 --> 01:00:03,789 is a general-purpose computer, reconfigurable at runtime 1426 01:00:03,789 --> 01:00:06,039 to run natural language programs. 1427 01:00:06,039 --> 01:00:08,920 So the programs are given as prompts, 1428 01:00:08,920 --> 01:00:12,220 and then GPT runs the program by completing the document. 1429 01:00:12,219 --> 01:00:16,959 So personally, I really like these analogies to a computer. 1430 01:00:16,960 --> 01:00:18,639 It's just like a powerful computer, 1431 01:00:18,639 --> 01:00:22,199 and it's optimizable by gradient descent. 1432 01:00:22,199 --> 01:00:30,614 And I don't know-- 1433 01:00:30,614 --> 01:00:31,114 OK, yeah. 1434 01:00:31,114 --> 01:00:31,614 That's it. 1435 01:00:31,614 --> 01:00:33,376 [LAUGHTER] 1436 01:00:33,376 --> 01:00:35,460 You can read the tweets later, but that's it for now. 1437 01:00:35,460 --> 01:00:36,050 I'll just say thank you. 1438 01:00:36,050 --> 01:00:37,050 I'll just leave this up. 1439 01:00:45,367 --> 01:00:46,659 Sorry, I just found this tweet. 1440 01:00:46,659 --> 01:00:49,599 So it turns out that if you scale up the training set 1441 01:00:49,599 --> 01:00:51,940 and use a powerful enough neural net, like a transformer, 1442 01:00:51,940 --> 01:00:53,815 the network becomes a kind of general-purpose 1443 01:00:53,815 --> 01:00:54,720 computer over text. 1444 01:00:54,719 --> 01:00:56,527 So I think that's a nice way to look at it. 1445 01:00:56,527 --> 01:00:58,569 And instead of performing a single, fixed text sequence, 1446 01:00:58,570 --> 01:01:00,340 you can design the sequence in the prompt. 1447 01:01:00,340 --> 01:01:02,230 And because the transformer is both powerful 1448 01:01:02,230 --> 01:01:05,110 but also trained on a large enough, very hard dataset, 1449 01:01:05,110 --> 01:01:07,539 it becomes this general-purpose text computer.
1450 01:01:07,539 --> 01:01:11,199 And so I think that's kind of an interesting way to look at it. 1451 01:01:11,199 --> 01:01:13,371 Yeah. 1452 01:01:13,371 --> 01:01:16,750 [INAUDIBLE] 1453 01:02:01,289 --> 01:02:04,179 And I guess my question is [INAUDIBLE] how 1454 01:02:04,179 --> 01:02:05,597 much do you think [INAUDIBLE]? 1455 01:02:10,019 --> 01:02:25,795 really because it's mostly more efficient or [INAUDIBLE] 1456 01:02:25,795 --> 01:02:27,170 So I think there's a bit of that. 1457 01:02:27,170 --> 01:02:29,284 Yeah, so I would say RNNs, in principle, 1458 01:02:29,284 --> 01:02:31,456 yes, they can implement arbitrary programs. 1459 01:02:31,456 --> 01:02:33,664 I think it's kind of a useless statement, to some extent. 1460 01:02:33,664 --> 01:02:35,795 They're probably-- 1461 01:02:35,795 --> 01:02:37,670 they're probably expressive 1462 01:02:37,670 --> 01:02:40,369 in the sense of power, in that they can implement 1463 01:02:40,369 --> 01:02:43,069 these arbitrary functions. 1464 01:02:43,070 --> 01:02:44,250 But they're not optimizable. 1465 01:02:44,250 --> 01:02:46,250 And they're certainly not efficient, because they 1466 01:02:46,250 --> 01:02:47,750 are serial computing devices. 1467 01:02:50,163 --> 01:02:51,829 So if you look at it as a compute graph, 1468 01:02:51,829 --> 01:02:58,264 RNNs are a very long, thin compute graph. 1469 01:02:58,264 --> 01:03:00,650 Like, if you stretched out the neurons and looked-- 1470 01:03:00,650 --> 01:03:02,255 take all the individual neurons and their interconnectivity, 1471 01:03:02,255 --> 01:03:04,460 stretch them out, and try to visualize them-- 1472 01:03:04,460 --> 01:03:07,070 RNNs would be like a very long graph, and that's bad. 1473 01:03:07,070 --> 01:03:08,570 And it's bad also for optimizability. 1474 01:03:08,570 --> 01:03:10,980 I don't exactly know why, 1475 01:03:10,980 --> 01:03:13,789 but the rough intuition is, when you're backpropagating, 1476 01:03:13,789 --> 01:03:15,574 you don't want to make too many steps. 1477 01:03:15,574 --> 01:03:19,384 And so transformers are a shallow, wide graph, and so 1478 01:03:19,385 --> 01:03:23,983 from supervision to inputs is a very small number of hops. 1479 01:03:23,983 --> 01:03:25,400 And there are long residual pathways, 1480 01:03:25,400 --> 01:03:26,990 which make gradients flow very easily. 1481 01:03:26,989 --> 01:03:28,364 And there are all these layer norms 1482 01:03:28,364 --> 01:03:32,509 to control the scales of all of those activations. 1483 01:03:32,510 --> 01:03:34,910 And so there are not too many hops, 1484 01:03:34,909 --> 01:03:36,980 and you're going from supervision to input 1485 01:03:36,980 --> 01:03:40,840 very quickly, and it just flows through the graph. 1486 01:03:40,840 --> 01:03:42,420 And it can all be done in parallel, 1487 01:03:42,420 --> 01:03:43,724 so you don't need to do this-- 1488 01:03:43,724 --> 01:03:46,029 in encoder and decoder RNNs, you have to go from the first word, 1489 01:03:46,030 --> 01:03:47,447 then the second word, then the third word. 1490 01:03:47,447 --> 01:03:49,329 But here in the transformer, every single word 1491 01:03:49,329 --> 01:03:54,699 is processed completely in parallel, which is kind of a-- 1492 01:03:54,699 --> 01:03:57,039 So I think all three of these 1493 01:03:57,039 --> 01:03:57,719 properties are really important.
1494 01:03:57,719 --> 01:04:00,399 And I think number three is less talked about but extremely 1495 01:04:00,400 --> 01:04:03,710 important, because in deep learning, scale matters. 1496 01:04:03,710 --> 01:04:06,099 And so the size of the network that you can train 1497 01:04:06,099 --> 01:04:08,509 is extremely important. 1498 01:04:08,510 --> 01:04:10,580 And so if it's efficient on the current hardware, 1499 01:04:10,579 --> 01:04:11,746 then you can make it bigger. 1500 01:04:14,945 --> 01:04:17,900 You mentioned that if you do it with multiple modalities 1501 01:04:17,900 --> 01:04:19,740 of data, [INAUDIBLE]. 1502 01:04:21,722 --> 01:04:22,889 How does that actually work? 1503 01:04:22,889 --> 01:04:26,359 Do you leave the different data as different tokens, 1504 01:04:26,360 --> 01:04:29,220 or is it [INAUDIBLE]? 1505 01:04:29,219 --> 01:04:31,349 No, so yeah, so you take your image, 1506 01:04:31,349 --> 01:04:33,239 and you just chop it up into patches. 1507 01:04:33,239 --> 01:04:35,369 So there's the first thousand tokens or whatever. 1508 01:04:35,369 --> 01:04:37,139 And now, I have a special-- 1509 01:04:37,139 --> 01:04:40,934 so radar could come in also, but I don't actually 1510 01:04:40,934 --> 01:04:43,920 want to make up a representation of radar. 1511 01:04:43,920 --> 01:04:46,075 But you just need to chop it up and enter it. 1512 01:04:46,074 --> 01:04:47,699 And then you have to encode it somehow. 1513 01:04:47,699 --> 01:04:48,659 Like, the transformer needs to know 1514 01:04:48,659 --> 01:04:49,951 that they're coming from radar. 1515 01:04:49,952 --> 01:04:52,290 So you create a special-- 1516 01:04:52,289 --> 01:04:55,706 you have some kind of special token for that-- 1517 01:04:55,706 --> 01:04:57,289 these radar tokens are what's slightly 1518 01:04:57,289 --> 01:04:58,759 different in the representation, and it's 1519 01:04:58,760 --> 01:05:00,050 learnable by gradient descent. 1520 01:05:00,050 --> 01:05:03,500 And vehicle information would also 1521 01:05:03,500 --> 01:05:07,920 come in with a special embedded token that can be learned. 1522 01:05:07,920 --> 01:05:09,289 So-- 1523 01:05:09,289 --> 01:05:11,654 So how do you line those up before really-- 1524 01:05:11,655 --> 01:05:12,830 Actually, you don't. 1525 01:05:12,829 --> 01:05:13,938 It's all just a set. 1526 01:05:13,938 --> 01:05:14,480 And there's-- 1527 01:05:14,480 --> 01:05:18,744 Even the [INAUDIBLE] 1528 01:05:18,744 --> 01:05:20,869 Yeah, it's all just a set, but you can positionally 1529 01:05:20,869 --> 01:05:23,190 encode these sets if you want. 1530 01:05:23,190 --> 01:05:26,150 So positional encoding means you can 1531 01:05:26,150 --> 01:05:28,130 hardwire, for example, the coordinates, 1532 01:05:28,130 --> 01:05:29,510 like using [INAUDIBLE]. 1533 01:05:29,510 --> 01:05:31,310 You can hardwire that, but it's better 1534 01:05:31,309 --> 01:05:33,380 if you don't hardwire the position. 1535 01:05:33,380 --> 01:05:34,768 It's just a vector that is always 1536 01:05:34,768 --> 01:05:35,934 hanging out at this location. 1537 01:05:35,934 --> 01:05:37,909 Whatever content is there, it just gets added on. 1538 01:05:37,909 --> 01:05:39,289 And this vector is trainable by backprop. 1539 01:05:39,289 --> 01:05:40,164 That's how you do it.
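A toy, hypothetical sketch of that "throw everything into the set" recipe, with a learnable per-modality type embedding so the transformer can tell which tokens came from which sensor; all names and sizes here are made up for illustration:

    import torch
    import torch.nn as nn

    n_embd = 64
    image_tokens = torch.randn(1, 9, n_embd)   # e.g. nine projected image patches
    radar_tokens = torch.randn(1, 4, n_embd)   # e.g. four projected radar features

    type_emb = nn.Embedding(2, n_embd)         # 0 = image, 1 = radar; learned by backprop
    image_tokens = image_tokens + type_emb(torch.zeros(1, 9, dtype=torch.long))
    radar_tokens = radar_tokens + type_emb(torch.ones(1, 4, dtype=torch.long))

    tokens = torch.cat([image_tokens, radar_tokens], dim=1)   # one set of 13 nodes
    # self-attention over this combined set then lets the modalities figure out
    # how everything should communicate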
1542 01:05:48,735 --> 01:05:51,400 They seem to work, but it seems like they're sometimes 1543 01:05:51,400 --> 01:06:08,867 [INAUDIBLE] 1544 01:06:08,867 --> 01:06:10,659 I'm not sure if I understand your question. 1545 01:06:10,659 --> 01:06:11,295 [LAUGHTER] 1546 01:06:11,295 --> 01:06:12,700 So I mean, the positional encodings-- 1547 01:06:12,699 --> 01:06:14,619 they're actually not-- 1548 01:06:14,619 --> 01:06:16,969 OK, so they have very little inductive bias, or something 1549 01:06:16,969 --> 01:06:17,469 like that. 1550 01:06:17,469 --> 01:06:19,636 They're just vectors hanging out at each location always, 1551 01:06:19,637 --> 01:06:23,900 and you're trying to help the network in some way. 1552 01:06:23,900 --> 01:06:28,710 And I think the intuition is good, 1553 01:06:28,710 --> 01:06:30,490 but if you have enough data, usually, 1554 01:06:30,489 --> 01:06:33,699 trying to mess with it is a bad thing. 1555 01:06:33,699 --> 01:06:35,199 Trying to enter knowledge when you 1556 01:06:35,199 --> 01:06:36,574 have enough knowledge in the data 1557 01:06:36,574 --> 01:06:38,164 set itself is not usually productive. 1558 01:06:38,164 --> 01:06:40,164 So it all really depends on what scale you're at. 1559 01:06:40,164 --> 01:06:41,949 If you have infinite data, then you actually 1560 01:06:41,949 --> 01:06:43,059 want to encode less and less. 1561 01:06:43,059 --> 01:06:44,299 That turns out to work better. 1562 01:06:44,300 --> 01:06:46,269 And if you have very little data, then you actually do 1563 01:06:46,269 --> 01:06:47,230 want to encode some biases. 1564 01:06:47,230 --> 01:06:49,179 And if you have a much smaller data set, then 1565 01:06:49,179 --> 01:06:50,596 maybe convolutions are a good idea, 1566 01:06:50,597 --> 01:06:55,269 because you actually have this bias coming from your filters. 1567 01:06:55,269 --> 01:06:58,969 But I think-- so the transformer is extremely general, 1568 01:06:58,969 --> 01:07:01,230 but there are ways to mess with the encodings 1569 01:07:01,230 --> 01:07:02,271 to put in more structure. 1570 01:07:02,271 --> 01:07:05,039 You could, for example, encode [INAUDIBLE] and fix it, 1571 01:07:05,039 --> 01:07:07,164 or you could actually go to the attention mechanism 1572 01:07:07,164 --> 01:07:10,831 and say, OK, if my image is chopped up into patches, 1573 01:07:10,831 --> 01:07:13,039 this patch can only communicate with this neighborhood. 1574 01:07:13,039 --> 01:07:15,170 And you just do that in the attention matrix: 1575 01:07:15,170 --> 01:07:18,152 you just mask out whatever you don't want to communicate. 1576 01:07:18,152 --> 01:07:19,610 And people really play with this, 1577 01:07:19,610 --> 01:07:22,724 because the full attention is inefficient. 1578 01:07:22,724 --> 01:07:25,159 So they will intersperse, for example, layers 1579 01:07:25,159 --> 01:07:26,869 that only communicate in little patches 1580 01:07:26,869 --> 01:07:28,639 with layers that communicate globally. 1581 01:07:28,639 --> 01:07:30,679 And they will do all kinds of tricks like that. 1582 01:07:30,679 --> 01:07:33,922 So you can slowly bring in more inductive bias. 1583 01:07:33,922 --> 01:07:35,630 You can do that, but the inductive biases 1584 01:07:35,630 --> 01:07:38,990 are factored out from the core transformer. 1585 01:07:38,989 --> 01:07:41,957 They are factored out into the interconnectivity 1586 01:07:41,958 --> 01:07:42,500 of the nodes.
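A rough sketch of the masking idea, assuming PyTorch's MultiheadAttention and an arbitrary window size (my construction, not code from the lecture): each token is restricted to a local neighborhood by masking the attention matrix, and local layers can be interspersed with full-attention layers.

```python
import torch
import torch.nn as nn

T, D, W = 64, 32, 4                       # tokens, width, half-window (assumed)
x = torch.randn(T, 1, D)

# Boolean mask over the attention matrix: True entries are *blocked*,
# so token i may only communicate with tokens j where |i - j| <= W.
idx = torch.arange(T)
mask = (idx[None, :] - idx[:, None]).abs() > W

attn = nn.MultiheadAttention(embed_dim=D, num_heads=4)
local_out, _ = attn(x, x, x, attn_mask=mask)   # local-neighborhood layer
global_out, _ = attn(x, x, x)                  # full-attention layer

# A genuinely efficient implementation would skip the masked entries rather
# than compute and discard them; the mask here just shows how structure is
# injected without touching the core transformer.
```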
1587 01:07:42,500 --> 01:07:44,909 And they are factored out into the positional encodings, 1588 01:07:44,909 --> 01:07:49,657 and you can mess with those to save computation. 1589 01:07:49,657 --> 01:08:01,067 [INAUDIBLE] 1590 01:08:02,530 --> 01:08:06,407 So there are probably about 200 papers on this now, if not more. 1591 01:08:06,407 --> 01:08:07,990 They're kind of hard to keep track of. 1592 01:08:07,989 --> 01:08:10,119 Honestly, my Safari browser-- it's 1593 01:08:10,119 --> 01:08:13,750 all up on my computer, like 200 open tabs. 1594 01:08:13,750 --> 01:08:20,579 But yes, I'm not even sure if I want 1595 01:08:20,579 --> 01:08:23,609 to pick my favorite, honestly. 1596 01:08:23,609 --> 01:08:29,904 Yeah, [INAUDIBLE] 1597 01:08:42,600 --> 01:08:45,146 Maybe you can use a transformer like that [INAUDIBLE] 1598 01:08:45,145 --> 01:08:46,978 The other one that I actually like even more 1599 01:08:46,979 --> 01:08:49,289 is, potentially, keep the context length fixed 1600 01:08:49,289 --> 01:08:53,086 but allow the network to somehow use a scratch pad. 1601 01:08:53,087 --> 01:08:55,545 And the way this works is you teach the transformer, 1602 01:08:55,545 --> 01:08:57,869 somehow, via examples in [INAUDIBLE], hey, 1603 01:08:57,869 --> 01:09:00,265 you actually have a scratch pad. 1604 01:09:00,265 --> 01:09:01,890 Basically, you can't remember too much. 1605 01:09:01,890 --> 01:09:02,939 Your context length is finite. 1606 01:09:02,939 --> 01:09:04,200 But you can use a scratch pad. 1607 01:09:04,199 --> 01:09:06,426 And you do that by emitting a start-scratch-pad token, 1608 01:09:06,426 --> 01:09:08,759 then writing whatever you want to remember, and then 1609 01:09:08,760 --> 01:09:10,079 an end-scratch-pad token. 1610 01:09:10,079 --> 01:09:12,750 And then you continue with whatever you want. 1611 01:09:12,750 --> 01:09:14,345 And then later, when it's decoding, 1612 01:09:14,345 --> 01:09:15,720 you actually have special logic 1613 01:09:15,720 --> 01:09:18,090 such that when you detect the start-scratch-pad token, 1614 01:09:18,090 --> 01:09:19,739 you save whatever it puts 1615 01:09:19,739 --> 01:09:22,639 in there in an external store and allow it to attend over it. 1616 01:09:22,640 --> 01:09:25,140 So basically, you can teach the transformer dynamically-- 1617 01:09:25,140 --> 01:09:27,479 because it's such a good meta-learner-- 1618 01:09:27,479 --> 01:09:30,060 you can teach it dynamically to use other gizmos and gadgets 1619 01:09:30,060 --> 01:09:31,927 and allow it to expand its memory that way, 1620 01:09:31,926 --> 01:09:32,759 if that makes sense. 1621 01:09:32,760 --> 01:09:35,533 It's just like a human learning to use a notepad, right? 1622 01:09:35,533 --> 01:09:37,200 You don't have to keep it all in your brain. 1623 01:09:37,199 --> 01:09:39,119 So keeping things in your brain is like the context length 1624 01:09:39,119 --> 01:09:39,994 of the transformer. 1625 01:09:39,994 --> 01:09:42,119 But maybe we can just give it a notebook. 1626 01:09:42,119 --> 01:09:45,149 And then it can query the notebook, and read from it, 1627 01:09:45,149 --> 01:09:46,396 and write to it. 1628 01:09:46,396 --> 01:09:48,689 [INAUDIBLE] transformer to plug in another transformer. 1629 01:09:48,689 --> 01:09:50,645 [LAUGHTER] 1630 01:09:53,090 --> 01:09:58,140 [INAUDIBLE] 1631 01:10:09,140 --> 01:10:10,520 I don't know if I detected that. 1632 01:10:10,520 --> 01:10:12,853 I feel like-- did you feel like there was more than just 1633 01:10:12,853 --> 01:10:14,720 a long prompt that's unfolding?
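A toy sketch of the scratch-pad decoding loop described above; every name here (the marker tokens, the next_token callable, the block size) is hypothetical, since no concrete implementation was given in the talk. Text emitted between the markers is saved to an external store rather than counting against the fixed context window.

```python
# All names below are hypothetical illustrations of the idea, not a real API.
START, END = "<scratchpad>", "</scratchpad>"   # assumed special marker tokens

def decode_with_scratchpad(next_token, prompt, block_size=1024, max_steps=200):
    """next_token: any callable mapping a list of tokens to the next token."""
    context, memory, in_pad = list(prompt), [], False
    for _ in range(max_steps):
        # The model attends over the external memory plus a fixed-size
        # window of recent context (the finite "brain").
        tok = next_token(memory + context[-block_size:])
        if tok == START:
            in_pad = True              # start writing to the scratch pad
        elif tok == END:
            in_pad = False             # resume normal emission
        elif in_pad:
            memory.append(tok)         # saved externally, off the context window
        else:
            context.append(tok)        # ordinary token in the finite window
    return context, memory
```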
1634 01:10:14,720 --> 01:10:19,930 Yeah, [INAUDIBLE] 1635 01:10:19,930 --> 01:10:22,960 I didn't try it extensively, but I did see a [INAUDIBLE] event. 1636 01:10:22,960 --> 01:10:25,270 And I felt like the block size was just moved. 1637 01:10:28,162 --> 01:10:28,829 Maybe I'm wrong. 1638 01:10:28,829 --> 01:10:31,199 I don't actually know about the internals of ChatGPT. 1639 01:10:31,199 --> 01:10:33,085 We have two online questions. 1640 01:10:33,085 --> 01:10:35,984 So one question is, "What do you think about the architecture 1641 01:10:35,984 --> 01:10:38,889 [INAUDIBLE]?" 1642 01:10:38,890 --> 01:10:39,510 S4? 1643 01:10:39,510 --> 01:10:40,930 S4. 1644 01:10:40,930 --> 01:10:41,430 I'm sorry. 1645 01:10:41,430 --> 01:10:42,670 I don't know S4. 1646 01:10:42,670 --> 01:10:45,340 Which one is this one? 1647 01:10:45,340 --> 01:10:47,710 The second question-- this one's a personal question. 1648 01:10:47,710 --> 01:10:49,725 "What are you going to work on next?" 1649 01:10:49,725 --> 01:10:51,364 [INAUDIBLE] 1650 01:10:51,364 --> 01:10:53,739 I mean, so right now, I'm working on things like nanoGPT. 1651 01:10:53,739 --> 01:10:54,939 Where is nanoGPT? 1652 01:10:58,765 --> 01:11:01,140 I mean, I'm basically going slightly from computer vision 1653 01:11:01,140 --> 01:11:03,869 and computer vision-based products 1654 01:11:03,869 --> 01:11:05,309 to doing a little bit in the language domain. 1655 01:11:05,310 --> 01:11:06,472 Where's ChatGPT? 1656 01:11:06,471 --> 01:11:07,416 OK, nanoGPT. 1657 01:11:07,416 --> 01:11:10,215 So originally, I had minGPT, which I rewrote into nanoGPT. 1658 01:11:10,215 --> 01:11:11,819 And I'm working on this. 1659 01:11:11,819 --> 01:11:14,219 I'm trying to reproduce GPTs, and I think 1660 01:11:14,220 --> 01:11:16,050 something like ChatGPT, 1661 01:11:16,050 --> 01:11:17,970 incrementally improved in a product fashion, 1662 01:11:17,970 --> 01:11:19,980 would be extremely interesting. 1663 01:11:19,979 --> 01:11:23,009 And I think a lot of people feel it, 1664 01:11:23,010 --> 01:11:24,960 and that's why it spread so widely. 1665 01:11:24,960 --> 01:11:28,020 So I think there's something like a Google++ 1666 01:11:28,020 --> 01:11:31,960 to build that I think is more interesting. 1667 01:11:31,960 --> 01:11:34,649 Shall we give our speaker a round of applause?