Okay, all right. So we'll continue with transformers today, part two. We're going to do the second pass, a deeper pass through the transformer stack. And I think the next 30 minutes are potentially the most demanding 30 minutes of the entire course. Okay, with that motivational speech, let's get going.

So, quick review: why do we want transformers? Because we want an architecture that, one, can generate output that has the same length as the input. Same length. Oh, there it is. Number two, we want to take the context into account, and, three, we want to take the order into account. As you saw last time, the transformer architecture delivers on those three requirements. So just a quick review: if you have a phrase like "the train left the station," we have all these little arrows, which stand for the standalone, or uncontextual, embeddings. And sometimes this works, so I'm going to put it close to me here. Okay. All right.
So here, we start with standalone embeddings, i.e. uncontextual embeddings, which have either been pre-trained or are random; it doesn't really matter. If you look at the Colab we did the other day, we actually just start with random weights for the embeddings. Then we add positional embeddings to them. So for each word we take its standalone embedding and its positional embedding, and we literally just add them up, element by element, to get a total embedding; that's called the positional input embedding of each word. This whole thing goes into the transformer encoder stack, and what pops out the other end is contextual embeddings. Okay, so that's the overall flow.

Now, we applied this transformer stack to the word-to-slot classification problem, where we basically take every incoming natural language query, calculate its positional embeddings, and run it through the transformer stack.
Then we get contextual embeddings, and at this point, since each embedding that comes out needs to be classified into one of 125 possibilities, we run it through a ReLU and attach a softmax to each embedding. This is basically what we did last class. So this is the transformer encoder. Now, any questions on this before I continue?

>> I was wondering, how do you decide where to add more self-attention and where to add transformer layers? You mentioned that GPT-3 has 96 of them.

>> Yeah, right, so GPT-3 has 96 transformer blocks; each one is a block. So I think the question goes to: do you add more attention heads within a single block, or do you add lots of blocks? And both are good things to do. What increasing the number of attention heads in a block does for you is that it allows you to pick up more patterns at that level of abstraction.
But if you add more blocks, then much like later convolutional filters can build on earlier convolutional filters, you're going up the levels of abstraction. To go to vision, for instance: you have the notion of lines in the beginning, then you have a notion of edges, which are two lines, and then you have noses, eyes, faces, and so on and so forth. So both are worth doing. Typically you find that people have maybe five, six, up to a dozen heads; we'll see how many heads a couple of architectures use later on today. And the more you scale up, the more capable the model becomes, as long as you have enough data to train it well. So there's the perennial question of whether we have enough data to train this large a model, because if we don't, we might run into overfitting problems and so on. That's always the trade-off.

Okay, so here I just want to quickly switch to the Colab, because we didn't have a chance to finish it.
I'm not going to run it, because it would take some time. So, where we left off last time: we basically took the architecture we just saw on the slide and wrote it as a Keras model. I went through this model in the last class, so I'm not going to go through it all over again. What we did not do last class was actually run it. So if you run it, you can just train for 10 epochs, like we normally do: give it data, give it a number of epochs, choose a particular batch size (I arbitrarily chose 64), run it for 10 epochs, and then evaluate it on the test set. You get 99% accuracy on this problem. One transformer block, that's it. Of course, there's a little trickiness going on here, because a naive model can literally say that every word that comes in is "other," O. And since the O's are the majority of the words, it's not going to do badly, right?
It's like having a classification problem in which one class is very predominant, where the naive way to do well is to just predict the majority class every time something comes in. The same thing happens here. But if you adjust for that, it turns out that the accuracy on the non-O slots, which is really what you care about, is actually 93%, which is pretty good. Okay. And then I had some examples of lots of fun queries you can do, including queries where I try to break stuff, like "cheapest flight to fly from MIT to Mars," to see what happens. So have fun with it. Okay, back to PowerPoint.

So this is what we had. Now, what we're going to do in today's class is take the encoder we built last time and introduce three new complications into it. When we finish introducing these three complications, we will have the actual transformer that was invented in the 2017 paper. Okay. All right.
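As an aside, the majority-class caveat from a moment ago is easy to make concrete. Here's a small sketch with made-up labels (not the actual dataset), comparing overall accuracy with accuracy restricted to the non-O slots:

```python
import numpy as np

# Hypothetical slot labels, purely for illustration:
# 0 stands for "O" (other); 1, 2, 3 stand for real slot classes.
y_true = np.array([0, 0, 0, 0, 0, 0, 0, 1, 2, 3])

# A naive model that always predicts the majority class "O".
y_pred = np.zeros_like(y_true)

overall = (y_pred == y_true).mean()                        # flattered by the O's
non_o_mask = y_true != 0
non_o = (y_pred[non_o_mask] == y_true[non_o_mask]).mean()  # the slots you care about
```

Here the naive baseline scores 70% overall but 0% on the non-O slots, which is why the adjusted 93% figure is the one worth quoting.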
The first tweak is the hardest tweak, so we'll slowly work our way up to it. The thing to remember is, let's review self-attention. What is self-attention? You have a bunch of words, and we further said that for any particular word, like "station," we want to take its positional embedding and then make it contextual. And the way we do that is by taking each word's embedding and then calculating the dot products with all the other words. Then, since these dot products can be positive or negative, we want to make them all positive and normalize them so that they nicely add up to one. So we exponentiate them and then divide by the total, right? Which is basically softmax. And when you do that, you have nice fractions that add up to one. And then we said, well, the contextual embedding for w6 is just all these weights s1, s2, all the way to s6, multiplied by the original w's, and then you get the contextual embedding for w6. So this is the basic logic we covered last time.
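As a sketch, that per-word recipe for w6 looks like this in NumPy (six words with toy 4-dimensional embeddings, random stand-ins for real values):

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(6, 4))      # six positional embeddings, one row per word (toy values)

scores = W @ W[5]                # dot product of w6 with every word, itself included
s = np.exp(scores) / np.exp(scores).sum()   # softmax: positive weights that sum to 1

w6_hat = s @ W                   # weighted average of the original embeddings
```

The weights `s` play the role of s1 through s6 above, and `w6_hat` is the contextual embedding.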
Now, it is obviously the case that we explained it only for one word, but we have to do the same exact operation for every one of the other words too, so that we can calculate w5-hat, w4-hat, w3-hat, and so on and so forth. So there are a lot of computations going on, and they all look kind of similar: you've got to do a bunch of dot products, you've got to do some softmaxing on them, and stuff like that. So the natural question is: is there a way to organize this very efficiently? And the short answer is yes. In fact, if you could not do that, there wouldn't be any transformer revolution, because it is that ability to package it up into a very interesting and efficient operation that allows you to put the whole thing on GPUs.

Okay, so now I'm going to switch to the iPad and give you some iPad scribblings of mine, which were concocted last night because I was very unhappy with the slides that follow. So we're going to do iPad. Okay. All right. So if it works, you folks are lucky.
If it doesn't work, last year's Huddle class is luckier. So let's shift to that. All right, so we're going to go here. Let's assume we have a simple example. Instead of "the train left the station," which is a long sentence, let's just say you have a simple sentence like "I love Huddle." Okay, so "I love Huddle" is what you have, and then you have these standalone embeddings w1, w2, w3. This comes into the self-attention layer, and let's assume that w1, w2, w3 are already positionally encoded, right? We have already added the positional encodings, all that stuff; it's all behind us. That all happens outside the transformer. So you get it here. Now, what you do is you actually make three copies of this thing. Okay? And let's call this whole thing X. I'm just giving it the name X. It's a matrix of these three vectors. And so the first copy goes up here, the second copy goes straight, and the third copy goes down.
And don't worry about the third copy just yet. So if you look at the first two copies, here is the key thing to focus on. Remember that we want to calculate dot products between all these vectors; basically, we want to calculate the dot product of every pair of vectors, every pair of words. The whole point of self-attention is that for every pair of words we figure out how attracted or related they are, right? Which means that we have to calculate all pairs of dot products. And so what you do is you take this matrix, w1, w2, w3, and you take the other copy that went up, and you transpose it. So when you transpose it, it all becomes nice and vertical: all the vectors came in like this, and when you transpose, they become vertical. And now what you do is you take w1 and multiply it by w1, then w1 by w2, then w1 by w3. You calculate all those dot products like that.
And when you do that, you have these nice cells where, for every pair of words, their dot product has been calculated in this grid. Okay. And the key thing to see here, which folks with a matrix algebra background will see immediately, is that all we are doing is taking this X, the matrix that came in, and X-transpose, the copy that we sent up and then brought back down, and doing the matrix multiplication X times X-transpose. That's all we're doing. And when we do that, we're getting this nice grid in which every pair of words has its dot product calculated for you with one matrix multiplication. Boom, done. Okay, so if you have three words, there are nine multiplications, right? So if you have a million words, that's a lot of multiplications: on the order of a trillion. And the reason to say "on the order" is that w1 · w3 is the same as w3 · w1, so there's some duplication here.
So you get this grid in one shot, with one matrix multiplication. And then, because each of these numbers is just a dot product, which can be negative or positive, we need to softmax it. So what we do is we take all these numbers and put them into a softmax function, which calculates a softmax for each row. And what do I mean by that? For each row, it takes each number, computes e raised to that number, and then divides by the sum of those exponentials for the row. And when you do that, okay, you can think of this operation as softmax applied to X times X-transpose, and you get this nice little table of numbers.

This table of numbers basically says that for the first word, w1, you take 0.1 of the first word, 0.7 of the second, 0.2 of the third, and add them up; we do a weighted average. So we have this table, and now the third copy shows up here, right there. So we do this table times that copy, which is just a matrix multiplication again.
And when we do that, we get the final contextual embeddings. So this one, for example, is just 0.1 · w1 + 0.7 · w2 + 0.2 · w3, right there. And you can see the same logic for the other rows as well. Okay, and you can read it later on; I will post this to make sure you understand exactly how it flowed. But the larger point I want you to focus on is that the entire self-attention operation we just looked at is this beautifully compact little matrix formula: softmax(X Xᵀ) X.

Okay: X comes in, you form X-transpose, you do a matrix multiplication, you do a softmax on top of it, then you multiply by X again, and boom, you're done. So that is the magic of taking the transformer stack and representing it using matrix operations, because then it's lightning fast on GPUs.

Okay. All right. That was the warm-up. Now let's crank it up a notch.
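The whole warm-up, softmax(X Xᵀ) X, is only a few lines of NumPy. A minimal sketch with toy sizes and random values:

```python
import numpy as np

def softmax_rows(M):
    E = np.exp(M - M.max(axis=1, keepdims=True))  # subtract the row max for stability
    return E / E.sum(axis=1, keepdims=True)

rng = np.random.default_rng(1)
X = rng.normal(size=(3, 4))   # "I love Huddle": 3 words, 4-dim embeddings (toy values)

S = softmax_rows(X @ X.T)     # 3x3 table of weights; every row sums to 1
X_hat = S @ X                 # contextual embeddings, one row per word
```

Row i of `S` holds the weights for word i, and row i of `X_hat` is its contextual embedding, exactly the weighted averages described above.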
So recall that in the last class I talked about the fact that in the self-attention operation, the w's are coming in, we're doing all this stuff with the w's, and we're getting some w-hats out, but there are no parameters. There's nothing to be learned inside the self-attention layer: there are no weights, there are no biases, there are no coefficients. So, well, okay, what are we learning then? Right? So what we now do is we're going to make the self-attention layer tunable. We're going to inject some weights into it, so that when we train it on an actual system, the weights will keep changing to adapt to the particularities of whatever problem you're working on. Right? So that takes us to the tunable self-attention layer.

Okay? The tunable self-attention layer. So this is the key thing to keep in mind. Any questions on this before I continue with the tunability thing?

Okay. Is this picture working out, by the way? Okay. All right.
So what we now do is we have the same exact logic as before, where we have this input that comes in. Okay, we have this input that comes in, and we call it X again, this matrix of embeddings. And where before we just sent out three copies, instead of doing that, what we're going to do is take each copy of X and actually multiply it by a matrix. Okay, this matrix is called the key matrix, and this matrix of numbers is weights that will be learned by backprop. So basically what we're saying is that when this thing comes in, let's see if there's a way to transform this X into some other set of embeddings which may be useful for your task. We don't know if they're going to be useful, but surely giving it weights which can be learned means giving it more expressive power, more modeling capacity. And whether it actually uses that capacity will depend on how much data you have and how well you train it.
And maybe if it's not useful, it won't use it. What I mean is, if transforming X actually doesn't really help at all, then what is this matrix going to be? It's going to be the identity matrix, because if you multiply X by the identity, you get X back again. So in the worst case, maybe it just says, "I have nothing to learn here," but maybe there is something you can learn. So that's what we do. So we multiply by this matrix A_K, and we come up with some transformed embeddings, and we call these things K.

Okay, K. Now, this K-Q-V terminology, as you will see, has its origins in the field of information retrieval, but I personally find that that interpretation is not super helpful, because transformers are used for lots of applications outside information retrieval. So I'm not going to go with that kind of interpretation. I'm going to go with the interpretation of: let's make each of these things tunable. Okay, and tunability means we need to give it weights. All right, so that's what we have here.
Now, the second copy: we did this with the first copy, so let's do the same thing with the second copy. We'll take the second copy and multiply it by some other matrix called A_Q, and when we are done with that, we get these embeddings, and we will call these embeddings Q.

Okay. Now, just like before, we will take K and we'll transpose it, so it all becomes nice and vertical. And then we'll do exactly the same as before: we'll calculate all these pair-wise dot products in one shot, with one matrix multiplication. And because we are calling this Q and we are calling this whole thing K, this product just becomes Q times Kᵀ. Okay, at the end of it you come up with a grid of numbers, just like before. And these numbers could be negative or positive, so we need to do the softmax on them to make sure they are well-behaved fractions that add up to one.
So we take this Q KT business 459 00:17:42,160 --> 00:17:48,320 and then we just put it 460 00:17:44,880 --> 00:17:50,720 through a softmax function for each row, 461 00:17:48,319 --> 00:17:52,639 and when we do that we'll get 462 00:17:50,720 --> 00:17:54,160 basically a table like the 463 00:17:52,640 --> 00:17:55,600 ones we saw before. By the way, the 464 00:17:54,160 --> 00:17:57,919 numbers here are the same just because I 465 00:17:55,599 --> 00:17:59,439 duplicated it, because I'm lazy; in 466 00:17:57,919 --> 00:18:00,480 reality, given it has gone through all 467 00:17:59,440 --> 00:18:03,120 these transformations, the numbers are 468 00:18:00,480 --> 00:18:05,440 not going to be the same, right? Uh, you 469 00:18:03,119 --> 00:18:08,719 have these numbers and then you take the 470 00:18:05,440 --> 00:18:10,080 final copy, which is X * AV, right? Each 471 00:18:08,720 --> 00:18:11,919 copy is getting multiplied by its own 472 00:18:10,079 --> 00:18:14,319 matrix, right? And this copy is being 473 00:18:11,919 --> 00:18:19,440 multiplied by AV. And let's call this X * 474 00:18:14,319 --> 00:18:21,519 AV. Okay? Which is here just V. 475 00:18:19,440 --> 00:18:24,640 And so what you have here, this 476 00:18:21,519 --> 00:18:26,319 softmax of Q KT times V, is exactly the same kind of 477 00:18:24,640 --> 00:18:28,080 dot product matrix 478 00:18:26,319 --> 00:18:30,000 multiplication as we saw before. So we have these 479 00:18:28,079 --> 00:18:32,240 contextual embeddings and that's what's 480 00:18:30,000 --> 00:18:34,798 coming out of the transformer 481 00:18:32,240 --> 00:18:36,960 block. So now the whole thing we did 482 00:18:34,798 --> 00:18:42,558 here can be represented 483 00:18:36,960 --> 00:18:47,038 as softmax of Q KT * V. Okay. So if we 484 00:18:42,558 --> 00:18:49,200 zoom in a bit. Come on. Okay. 485 00:18:47,038 --> 00:18:52,319 Okay. 486 00:18:49,200 --> 00:18:55,440 So X came in.
487 00:18:52,319 --> 00:18:59,519 Three tracks went here: 488 00:18:55,440 --> 00:19:01,360 X * AK, X * AQ, X * AV. And this thing 489 00:18:59,519 --> 00:19:03,918 is called K. This thing is called Q. 490 00:19:01,359 --> 00:19:06,079 This thing is called V. And then we do 491 00:19:03,919 --> 00:19:08,080 the same transpose as before. We do the 492 00:19:06,079 --> 00:19:09,839 dot-product thing to calculate the 493 00:19:08,079 --> 00:19:12,319 pairwise dot products for everything, 494 00:19:09,839 --> 00:19:15,119 which is just Q KT. We run it through a 495 00:19:12,319 --> 00:19:16,798 softmax. We get softmax of Q KT. We 496 00:19:15,119 --> 00:19:18,879 multiply it by V to do the final 497 00:19:16,798 --> 00:19:22,079 weighting, and then boom, the output comes, 498 00:19:18,880 --> 00:19:24,160 and that's this function. That's it. 499 00:19:22,079 --> 00:19:27,279 Okay. So what we have done is we have 500 00:19:24,160 --> 00:19:31,200 introduced three learnable 501 00:19:27,279 --> 00:19:34,319 matrices into the self attention layer. 502 00:19:31,200 --> 00:19:35,679 Okay. Now, 503 00:19:34,319 --> 00:19:37,439 okay. Let me just stop there for a sec. 504 00:19:35,679 --> 00:19:39,668 Questions. 505 00:19:37,440 --> 00:19:39,840 Yeah. 506 00:19:39,667 --> 00:19:43,119 [clears throat] 507 00:19:39,839 --> 00:19:44,159 >> Is there a relationship between AK, AQ, 508 00:19:43,119 --> 00:19:47,599 and AV? 509 00:19:44,160 --> 00:19:48,558 >> Independent, independent matrices. 510 00:19:47,599 --> 00:19:49,279 >> Yes. 511 00:19:48,558 --> 00:19:50,558 >> Like we have 512 00:19:49,279 --> 00:19:52,480 >> Could you use the microphone, please? 513 00:19:50,558 --> 00:19:55,038 >> Here we have three sets of parameters: K, 514 00:19:52,480 --> 00:19:58,240 Q and V.
Let's say the 515 00:19:55,038 --> 00:19:59,839 dimension, the total length, was let's say 516 00:19:58,240 --> 00:20:02,558 50. 517 00:19:59,839 --> 00:20:04,959 So you would have, uh, 50 for a 518 00:20:02,558 --> 00:20:07,678 set of parameters, like you'll have to... 519 00:20:04,960 --> 00:20:10,079 >> So if the dimension is 520 00:20:07,679 --> 00:20:13,038 50 long, what is coming in, the W's, are 50 521 00:20:10,079 --> 00:20:15,678 long; then the key, what comes out of 522 00:20:13,038 --> 00:20:20,599 it, if you want it to be 50 as well, 523 00:20:15,679 --> 00:20:20,600 this matrix needs to be 50 * 50, so 2500. 524 00:20:22,960 --> 00:20:27,519 >> [inaudible] 525 00:20:24,798 --> 00:20:30,400 >> What are the different things the 526 00:20:27,519 --> 00:20:30,798 three matrices are trying to... 527 00:20:30,400 --> 00:20:32,000 Sorry, 528 00:20:30,798 --> 00:20:33,679 what are the different things that the 529 00:20:32,000 --> 00:20:35,599 matrices are trying to learn? 530 00:20:33,679 --> 00:20:37,120 >> We don't know. All we are saying is that 531 00:20:35,599 --> 00:20:38,959 we have a self attention layer which can 532 00:20:37,119 --> 00:20:40,959 pay attention to every pair of words. 533 00:20:38,960 --> 00:20:43,120 But we need to give it some ways to 534 00:20:40,960 --> 00:20:45,759 transform what is coming in into 535 00:20:43,119 --> 00:20:48,000 potentially useful things. Right? As to 536 00:20:45,759 --> 00:20:49,679 their actual usefulness, we'll have to 537 00:20:48,000 --> 00:20:51,200 figure out if it actually helps or 538 00:20:49,679 --> 00:20:52,320 not. And of course, as you know, the 539 00:20:51,200 --> 00:20:54,240 punch line is that yeah, it helps 540 00:20:52,319 --> 00:20:55,279 massively. That's why we do it.
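The tunable self-attention just derived can be sketched in a few lines of numpy. This is a minimal illustrative sketch, not code from the course colab; the names AK, AQ, AV follow the notation on the iPad, and the sizes are made up:

```python
import numpy as np

def softmax_rows(scores):
    # shift by the row max for numerical stability, then normalize each
    # row into well-behaved fractions that add up to one
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def self_attention(X, AK, AQ, AV):
    K = X @ AK            # first copy, transformed by its own learnable matrix
    Q = X @ AQ            # second copy
    V = X @ AV            # third copy
    scores = Q @ K.T      # all pairwise dot products in one matrix multiplication
    return softmax_rows(scores) @ V   # weighted mix of the V's: contextual embeddings

# 3 "words", embeddings of length 4, random weights (learned by backprop later)
rng = np.random.default_rng(0)
X = rng.normal(size=(3, 4))
AK, AQ, AV = rng.normal(size=(4, 4)), rng.normal(size=(4, 4)), rng.normal(size=(4, 4))
out = self_attention(X, AK, AQ, AV)
print(out.shape)  # (3, 4): same number of words in, same embedding length out
```

With AK, AQ, and AV all set to the identity matrix, this reduces to the untuned dot-product version from the first pass, which is exactly the "nothing to learn" worst case mentioned earlier.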
In 541 00:20:54,240 --> 00:20:57,519 general, what you will find in the deep 542 00:20:55,279 --> 00:20:58,960 learning literature is that whenever you 543 00:20:57,519 --> 00:21:01,119 want to increase the capacity, the 544 00:20:58,960 --> 00:21:03,600 modeling capacity of a particular model, 545 00:21:01,119 --> 00:21:05,839 you just take a small piece and inject a 546 00:21:03,599 --> 00:21:07,599 little matrix multiplication into it. 547 00:21:05,839 --> 00:21:08,959 You take a vector that's showing up in 548 00:21:07,599 --> 00:21:10,879 the middle and then you make it run 549 00:21:08,960 --> 00:21:13,038 through a matrix to get another vector 550 00:21:10,880 --> 00:21:14,559 and then further after you run it 551 00:21:13,038 --> 00:21:17,119 through a matrix, you run it through a 552 00:21:14,558 --> 00:21:19,519 little ReLU as well. Even better. So 553 00:21:17,119 --> 00:21:22,158 that's how you inject modeling capacity 554 00:21:19,519 --> 00:21:23,359 into the middle of these networks. Okay? 555 00:21:22,159 --> 00:21:26,640 And that's what these people are doing 556 00:21:23,359 --> 00:21:29,519 here. Yeah. 557 00:21:26,640 --> 00:21:31,360 >> In the last step, you had the matrix V. 558 00:21:29,519 --> 00:21:33,038 So on the previous example, you had used 559 00:21:31,359 --> 00:21:35,359 the original matrix X. So could you just 560 00:21:33,038 --> 00:21:36,079 say for why is it not using X? What does 561 00:21:35,359 --> 00:21:38,479 that mean? 562 00:21:36,079 --> 00:21:40,319 >> So what we're saying is that the in the 563 00:21:38,480 --> 00:21:42,480 initial version we had three copies and 564 00:21:40,319 --> 00:21:44,000 we treated them all identical. Now we 565 00:21:42,480 --> 00:21:45,599 said well there are are there ways to 566 00:21:44,000 --> 00:21:47,519 transform each copy into some other 567 00:21:45,599 --> 00:21:48,959 representation which could be useful. 
So 568 00:21:47,519 --> 00:21:51,519 we may as well use three different 569 00:21:48,960 --> 00:21:52,960 matrices for it. Why stop with two? 570 00:21:51,519 --> 00:21:54,558 There are three opportunities to make 571 00:21:52,960 --> 00:21:56,558 them more expressive. We'll use all of 572 00:21:54,558 --> 00:21:59,558 them. 573 00:21:56,558 --> 00:21:59,558 >> Yeah. 574 00:21:59,759 --> 00:22:03,919 >> You mentioned that these are kind of 575 00:22:02,240 --> 00:22:05,359 you're tuning it. You're kind of 576 00:22:03,919 --> 00:22:06,960 fine-tuning it. Is there any risk? 577 00:22:05,359 --> 00:22:09,199 >> We're not fine-tuning it. Uh just to be 578 00:22:06,960 --> 00:22:10,880 clear on the on the vocabulary here. So 579 00:22:09,200 --> 00:22:12,880 we have added more weights to make them 580 00:22:10,880 --> 00:22:16,320 tunable. What that means is that we when 581 00:22:12,880 --> 00:22:17,760 we finally train this entire model, 582 00:22:16,319 --> 00:22:20,240 remember all the weights are going to be 583 00:22:17,759 --> 00:22:21,839 updated using back propagation, right? 584 00:22:20,240 --> 00:22:23,839 In particular, these matrices will also 585 00:22:21,839 --> 00:22:26,319 get updated using back propagation. 586 00:22:23,839 --> 00:22:27,678 >> So there's no risk of is there a risk of 587 00:22:26,319 --> 00:22:29,759 >> there's always the risk of overfitting 588 00:22:27,679 --> 00:22:31,038 when you add more parameters to a model 589 00:22:29,759 --> 00:22:34,079 >> which means that you have to look at the 590 00:22:31,038 --> 00:22:36,400 validation set and all that good stuff. 591 00:22:34,079 --> 00:22:39,038 We are basically adding more parameters 592 00:22:36,400 --> 00:22:40,720 in a very interesting way because we 593 00:22:39,038 --> 00:22:41,919 want to add more capacity to the self 594 00:22:40,720 --> 00:22:43,440 attention layer. 
We want to give it 595 00:22:41,919 --> 00:22:45,600 more of an ability to learn things from 596 00:22:43,440 --> 00:22:48,080 the data. Before, it could not learn 597 00:22:45,599 --> 00:22:51,119 anything. It could only do dot products. 598 00:22:48,079 --> 00:22:52,399 So we want to solve that problem. 599 00:22:51,119 --> 00:22:56,759 All right, I'm going to continue and 600 00:22:52,400 --> 00:22:56,759 we'll come back to this. Okay. Um, 601 00:22:57,359 --> 00:23:01,678 so, uh, all right, just for 602 00:22:59,359 --> 00:23:03,119 fun, I'm going to do this. Um, the 603 00:23:01,679 --> 00:23:05,519 original paper is called Attention Is 604 00:23:03,119 --> 00:23:07,599 All You Need. This is the transformer 605 00:23:05,519 --> 00:23:11,519 paper. 606 00:23:07,599 --> 00:23:14,399 You folks should read it at some point. 607 00:23:11,519 --> 00:23:17,400 Just want to show you something. 608 00:23:14,400 --> 00:23:17,400 Uh 609 00:23:20,000 --> 00:23:26,000 You see that? So that is the famous 610 00:23:22,319 --> 00:23:29,038 transformer formula. Okay. And the only 611 00:23:26,000 --> 00:23:31,440 thing we ignored is this root of dK 612 00:23:29,038 --> 00:23:33,119 business in the denominator under it. I 613 00:23:31,440 --> 00:23:35,759 wouldn't worry about it. The reason they 614 00:23:33,119 --> 00:23:37,199 have it is because with these softmaxes, when 615 00:23:35,759 --> 00:23:39,679 you have lots of numbers and some 616 00:23:37,200 --> 00:23:41,120 numbers are really, really big, what's going 617 00:23:39,679 --> 00:23:43,679 to happen is that all the other numbers 618 00:23:41,119 --> 00:23:45,519 are going to get squashed to zero. Okay. 619 00:23:43,679 --> 00:23:47,600 And so to make sure the gradient flows 620 00:23:45,519 --> 00:23:49,279 properly, they just divide by a 621 00:23:47,599 --> 00:23:51,918 particular number to make sure no number 622 00:23:49,279 --> 00:23:53,599 is too big.
Okay, that's a small 623 00:23:51,919 --> 00:23:54,880 but important bit of 624 00:23:53,599 --> 00:23:57,359 technical detail, which is why I ignored 625 00:23:54,880 --> 00:23:59,760 it on my iPad. But the rest of it, you 626 00:23:57,359 --> 00:24:03,918 can see, is exactly the formula we 627 00:23:59,759 --> 00:24:05,759 derived: softmax of Q KT times V. 628 00:24:03,919 --> 00:24:08,159 Okay, so this is the famous transformer 629 00:24:05,759 --> 00:24:10,240 formula, 630 00:24:08,159 --> 00:24:11,840 and congratulations, now you understand 631 00:24:10,240 --> 00:24:14,720 it. 632 00:24:11,839 --> 00:24:17,199 You seem less than fully convinced. 633 00:24:14,720 --> 00:24:19,120 Okay. 634 00:24:17,200 --> 00:24:21,600 Yes. Hi, iPad. 635 00:24:19,119 --> 00:24:24,079 Now, I have a bunch of slides which I had, 636 00:24:21,599 --> 00:24:25,678 but actually I'll come back to this. I 637 00:24:24,079 --> 00:24:27,359 had a bunch of other slides. This is 638 00:24:25,679 --> 00:24:28,880 from last year, uh, which actually 639 00:24:27,359 --> 00:24:30,000 explains what I did on the iPad in a 640 00:24:28,880 --> 00:24:32,240 very different way, without using any 641 00:24:30,000 --> 00:24:34,240 matrices and so on. I was looking at it 642 00:24:32,240 --> 00:24:36,480 last evening and I was getting very 643 00:24:34,240 --> 00:24:38,000 annoyed by these slides for some reason, 644 00:24:36,480 --> 00:24:40,480 because I felt that it wasn't really 645 00:24:38,000 --> 00:24:43,679 conveying the core idea: the 646 00:24:40,480 --> 00:24:45,919 ability of using matrix 647 00:24:43,679 --> 00:24:47,840 algebra to actually do this so 648 00:24:45,919 --> 00:24:49,278 efficiently and compactly, which is why I 649 00:24:47,839 --> 00:24:51,599 decided to hand-draw this thing on 650 00:24:49,278 --> 00:24:53,119 the iPad.
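For completeness, the paper's formula with the root-of-dK scaling included is a small change: just divide the scores before the softmax. An illustrative numpy sketch with made-up sizes:

```python
import numpy as np

def scaled_attention(Q, K, V):
    d_k = K.shape[-1]
    # the paper's Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V:
    # dividing by sqrt(d_k) keeps any single score from getting so big
    # that the softmax squashes all the other weights to zero
    scores = (Q @ K.T) / np.sqrt(d_k)
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (e / e.sum(axis=-1, keepdims=True)) @ V

# made-up example: 3 "words", keys/queries/values of length 64
rng = np.random.default_rng(1)
Q, K, V = (rng.normal(size=(3, 64)) for _ in range(3))
out = scaled_attention(Q, K, V)
print(out.shape)  # (3, 64)
```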
Okay, but you should read it 651 00:24:51,599 --> 00:24:55,439 afterwards to make sure that whatever 652 00:24:53,119 --> 00:24:56,798 you saw on the iPad actually matches 653 00:24:55,440 --> 00:24:58,880 this. Okay, because two different ways 654 00:24:56,798 --> 00:25:02,480 of understanding something always helps. 655 00:24:58,880 --> 00:25:05,360 Um, okay, so this is what we have here now; 656 00:25:02,480 --> 00:25:07,360 just to recall: 657 00:25:05,359 --> 00:25:08,798 by making self attention tunable we 658 00:25:07,359 --> 00:25:10,319 get a very interesting benefit, which is 659 00:25:08,798 --> 00:25:13,278 that when you have these different 660 00:25:10,319 --> 00:25:14,798 attention heads: before, 661 00:25:13,278 --> 00:25:16,798 you could have two attention heads, but 662 00:25:14,798 --> 00:25:19,278 because there were no parameters inside, 663 00:25:16,798 --> 00:25:21,440 their outputs would have been identical; 664 00:25:19,278 --> 00:25:23,119 because the inputs are the same for both, 665 00:25:21,440 --> 00:25:25,440 therefore the outputs would be identical. 666 00:25:23,119 --> 00:25:28,319 But now, since each attention head 667 00:25:25,440 --> 00:25:29,120 will have its own AQ 668 00:25:28,319 --> 00:25:32,000 matrix, 669 00:25:29,119 --> 00:25:34,239 the outputs are going to be different. 670 00:25:32,000 --> 00:25:36,319 That's why it makes sense to do the 671 00:25:34,240 --> 00:25:37,759 tunability thing, because that's what 672 00:25:36,319 --> 00:25:42,439 actually makes multiple attention heads 673 00:25:37,759 --> 00:25:42,440 actually useful. Um 674 00:25:43,038 --> 00:25:47,839 >> Is there actually any relationship 675 00:25:44,880 --> 00:25:49,520 between AK, AQ and AV, or is the A just 676 00:25:47,839 --> 00:25:51,439 from, like, a notation standpoint? 677 00:25:49,519 --> 00:25:54,720 >> Just notation.
The thing is, we want to 678 00:25:51,440 --> 00:25:56,320 use K, Q, V for the resulting matrices, and so I 679 00:25:54,720 --> 00:25:58,480 had to find something else to use for 680 00:25:56,319 --> 00:25:59,839 the first one, and I was like, okay, AK, AQ, 681 00:25:58,480 --> 00:26:03,360 and we at MIT, we do subscripts and 682 00:25:59,839 --> 00:26:05,038 superscripts, right? So, yeah. 683 00:26:03,359 --> 00:26:07,678 >> What is the size of the 684 00:26:05,038 --> 00:26:08,319 matrices? Are they, like, square matrices, 685 00:26:07,679 --> 00:26:10,400 or... 686 00:26:08,319 --> 00:26:12,158 >> Yeah, so typically what happens is that, 687 00:26:10,400 --> 00:26:14,240 um, you can think 688 00:26:12,159 --> 00:26:15,919 of it as a hyperparameter in some ways. 689 00:26:14,240 --> 00:26:17,200 Um, typically what people do in most 690 00:26:15,919 --> 00:26:19,038 implementations is that they will 691 00:26:17,200 --> 00:26:20,798 actually just preserve the size. So if 692 00:26:19,038 --> 00:26:22,400 the incoming embedding is 10, they'll 693 00:26:20,798 --> 00:26:24,720 make sure the thing coming out 694 00:26:22,400 --> 00:26:27,519 is also 10. So you just do a 10x10 695 00:26:24,720 --> 00:26:31,038 matrix to transform it. Uh, but the 696 00:26:27,519 --> 00:26:32,639 value AV matrix, on the other hand, 697 00:26:31,038 --> 00:26:35,599 there's a bit more technical stuff going 698 00:26:32,640 --> 00:26:37,200 on, where it often tends to be smaller. 699 00:26:35,599 --> 00:26:39,839 Um, so for example, let's say that your 700 00:26:37,200 --> 00:26:42,240 incoming is 100: you do 100 to 100 for 701 00:26:39,839 --> 00:26:44,480 the key, 100 to 100 for the query.
But if 702 00:26:42,240 --> 00:26:47,440 you have, say, five attention heads, you 703 00:26:44,480 --> 00:26:48,798 may do 100 to 20 for the V's, because 704 00:26:47,440 --> 00:26:51,600 ultimately all the V's are going to get 705 00:26:48,798 --> 00:26:53,918 concatenated into another 100 again. So 706 00:26:51,599 --> 00:26:55,278 I can tell you more offline, but 707 00:26:53,919 --> 00:26:56,240 broadly speaking, these things tend to 708 00:26:55,278 --> 00:26:58,720 preserve the dimension: 709 00:26:56,240 --> 00:27:00,240 10 in and 10 out. 710 00:26:58,720 --> 00:27:04,798 Yeah. 711 00:27:00,240 --> 00:27:06,640 >> So these, uh, AQ, uh, these numbers are 712 00:27:04,798 --> 00:27:07,599 random when you start with it, and then you 713 00:27:06,640 --> 00:27:11,159 allow it to backpropagate? 714 00:27:07,599 --> 00:27:11,158 >> Exactly. Exactly. 715 00:27:11,440 --> 00:27:15,640 So, all right, um, 716 00:27:17,359 --> 00:27:20,798 yeah, so the values in these matrices are 717 00:27:19,359 --> 00:27:23,918 weights learned through optimization 718 00:27:20,798 --> 00:27:25,839 using SGD. Uh, and then what that means 719 00:27:23,919 --> 00:27:27,679 is that 720 00:27:25,839 --> 00:27:29,599 each of these attention heads now has its own 721 00:27:27,679 --> 00:27:31,759 copy of these matrices. It has its own 722 00:27:29,599 --> 00:27:33,359 matrices, and over the course of back 723 00:27:31,759 --> 00:27:36,319 propagation these matrices will look 724 00:27:33,359 --> 00:27:38,558 very different. Okay. So, important: each 725 00:27:36,319 --> 00:27:40,639 attention head will have its own set 726 00:27:38,558 --> 00:27:42,079 of three matrices. So if you have 10 727 00:27:40,640 --> 00:27:45,080 attention heads, 30 matrices will be 728 00:27:42,079 --> 00:27:45,079 learned.
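The dimension bookkeeping in this answer can be sketched as follows. This is an illustrative sketch only: 100-long embeddings, five heads, 100-to-100 key and query matrices, 100-to-20 value matrices, all of which are hyperparameter choices rather than fixed rules:

```python
import numpy as np

d_model, n_heads, d_v = 100, 5, 20   # 5 heads * 20 = 100, so concatenation restores the size
rng = np.random.default_rng(0)
X = rng.normal(size=(6, d_model))    # 6 words, 100-long embeddings

head_outputs = []
for _ in range(n_heads):
    AK = rng.normal(size=(d_model, d_model))  # 100 -> 100 for the key
    AQ = rng.normal(size=(d_model, d_model))  # 100 -> 100 for the query
    AV = rng.normal(size=(d_model, d_v))      # 100 -> 20 for the value
    K, Q, V = X @ AK, X @ AQ, X @ AV
    scores = (Q @ K.T) / np.sqrt(d_model)     # pairwise scores, scaled
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    head_outputs.append((e / e.sum(axis=-1, keepdims=True)) @ V)  # each head: (6, 20)

out = np.concatenate(head_outputs, axis=-1)
print(out.shape)  # (6, 100): back to the incoming size
```

Each head here has its own three random matrices, so ten heads would mean thirty matrices to learn, matching the count above.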
729 00:27:46,400 --> 00:27:50,798 >> So by the math it seems like it's 730 00:27:48,558 --> 00:27:52,399 creating essentially a relationship 731 00:27:50,798 --> 00:27:54,639 between all of the content being 732 00:27:52,400 --> 00:27:56,240 ingested, and if 733 00:27:54,640 --> 00:27:58,080 you're ingesting all the content for 734 00:27:56,240 --> 00:28:00,399 each attention head, are there different 735 00:27:58,079 --> 00:28:01,839 categories of attention head type that 736 00:28:00,398 --> 00:28:03,199 you're trying to go after? 737 00:28:01,839 --> 00:28:04,798 >> Yeah. So basically, what we're trying to 738 00:28:03,200 --> 00:28:07,120 do is to say, for a particular attention 739 00:28:04,798 --> 00:28:09,038 head: in any particular sentence it 740 00:28:07,119 --> 00:28:10,558 may turn out to be the case that one 741 00:28:09,038 --> 00:28:12,240 pattern could be about the meanings of 742 00:28:10,558 --> 00:28:14,480 these words, right? Like the word bank and 743 00:28:12,240 --> 00:28:15,679 what it means, the words station, train, 744 00:28:14,480 --> 00:28:17,519 things like that. That's what really 745 00:28:15,679 --> 00:28:19,360 we've been talking about. But there is a 746 00:28:17,519 --> 00:28:21,200 whole other pattern to do with grammar 747 00:28:19,359 --> 00:28:23,759 and tense and things like that. There 748 00:28:21,200 --> 00:28:25,440 could be another one in terms of tone. 749 00:28:23,759 --> 00:28:26,960 All those things are very important. And 750 00:28:25,440 --> 00:28:28,640 a priori we don't know how many such 751 00:28:26,960 --> 00:28:30,079 patterns exist.
Much like in a 752 00:28:28,640 --> 00:28:31,600 convolutional network, we don't when 753 00:28:30,079 --> 00:28:33,199 we're designing how many filters to 754 00:28:31,599 --> 00:28:34,798 have, we don't know how many kinds of 755 00:28:33,200 --> 00:28:36,798 little things we have to detect, you 756 00:28:34,798 --> 00:28:38,158 know, vertical line, horizontal line, 757 00:28:36,798 --> 00:28:39,679 semicircle, quarter circle, stuff like 758 00:28:38,159 --> 00:28:41,840 that. So, you just give it a lot of 759 00:28:39,679 --> 00:28:45,000 capacity so that it can learn whatever 760 00:28:41,839 --> 00:28:45,000 it wants. 761 00:28:45,038 --> 00:28:49,359 All right. So, um so that that is the 762 00:28:47,440 --> 00:28:51,840 transformer encoder. So, we have done 763 00:28:49,359 --> 00:28:53,278 one the first of the three complications 764 00:28:51,839 --> 00:28:56,720 needed to make it like industrial 765 00:28:53,278 --> 00:28:58,159 strength and legit. Uh the second thing 766 00:28:56,720 --> 00:29:02,720 we do is something called the residual 767 00:28:58,159 --> 00:29:05,120 connection. 
So what we do is that 768 00:29:02,720 --> 00:29:08,798 whatever comes out here, right, W1 through 769 00:29:05,119 --> 00:29:11,278 W6 goes in and comes out as W1 hat, W2 hat, 770 00:29:08,798 --> 00:29:13,759 and so on and so forth, right? 771 00:29:11,278 --> 00:29:16,240 Actually, sorry: what comes out here is 772 00:29:13,759 --> 00:29:18,720 the hats, but what comes out here is some 773 00:29:16,240 --> 00:29:20,079 intermediate W's, right? That is what the 774 00:29:18,720 --> 00:29:22,399 self attention is going to give you: some 775 00:29:20,079 --> 00:29:24,079 intermediate W's. What we do is, 776 00:29:22,398 --> 00:29:26,479 because what's coming out here, these 777 00:29:24,079 --> 00:29:28,720 vectors, are the same length as what goes 778 00:29:26,480 --> 00:29:29,440 in, we can just add them element by 779 00:29:28,720 --> 00:29:32,159 element. 780 00:29:29,440 --> 00:29:35,120 So we take the input and we actually add 781 00:29:32,159 --> 00:29:37,679 it to what comes out. 782 00:29:35,119 --> 00:29:39,918 So why would we want to do that? Why 783 00:29:37,679 --> 00:29:41,600 would we want to, you know, go to a lot of 784 00:29:39,919 --> 00:29:43,520 trouble to process this thing and then, 785 00:29:41,599 --> 00:29:45,759 when it comes out, we literally add 786 00:29:43,519 --> 00:29:49,879 in the original input? What 787 00:29:45,759 --> 00:29:49,879 do you think the intuition is? 788 00:29:52,398 --> 00:29:57,918 So it turns out, think of it this way. You 789 00:29:56,240 --> 00:30:00,000 have a bunch of inputs. You send it to a 790 00:29:57,919 --> 00:30:02,240 neural network. It transforms it and 791 00:30:00,000 --> 00:30:04,798 gives you something else. Right?
At that 792 00:30:02,240 --> 00:30:06,159 point, you might be thinking, well, 793 00:30:04,798 --> 00:30:07,519 everything that go everything that 794 00:30:06,159 --> 00:30:10,240 happens in the network from that point 795 00:30:07,519 --> 00:30:12,558 onward can no longer see your original 796 00:30:10,240 --> 00:30:14,640 input. It can only work with the 797 00:30:12,558 --> 00:30:17,599 transformed input. Right? But what if 798 00:30:14,640 --> 00:30:20,080 your transformations are not great? 799 00:30:17,599 --> 00:30:22,798 So as an insurance policy what you can 800 00:30:20,079 --> 00:30:24,798 do is you can take the the transform 801 00:30:22,798 --> 00:30:27,839 stuff and you can take the original 802 00:30:24,798 --> 00:30:30,158 stuff and send both in. 803 00:30:27,839 --> 00:30:31,439 Right? And this whole thing is and you 804 00:30:30,159 --> 00:30:33,120 can Google it. It's called like a wide 805 00:30:31,440 --> 00:30:35,200 and deep network and things like that. 806 00:30:33,119 --> 00:30:37,278 But the whole point is that let's not 807 00:30:35,200 --> 00:30:39,440 lose the original input anywhere. Let's 808 00:30:37,278 --> 00:30:40,880 also send it along. But if you keep 809 00:30:39,440 --> 00:30:42,080 adding the original input to every 810 00:30:40,880 --> 00:30:43,440 intermediate layer, it's going to get 811 00:30:42,079 --> 00:30:44,720 longer and longer and longer and bigger, 812 00:30:43,440 --> 00:30:46,640 which you don't want because you want it 813 00:30:44,720 --> 00:30:49,360 all to be the same size. So the simplest 814 00:30:46,640 --> 00:30:50,960 alternative is to just add them up. You 815 00:30:49,359 --> 00:30:52,879 take the transform stuff and you add the 816 00:30:50,960 --> 00:30:54,960 original input. You get the same thing 817 00:30:52,880 --> 00:30:57,919 again. 
What came 818 00:30:54,960 --> 00:31:00,240 in, W1, was a 100-long vector, and the 819 00:30:57,919 --> 00:31:02,240 transformed version is also 100 long. So 820 00:31:00,240 --> 00:31:04,000 just literally, 100 and 100, add them up. 821 00:31:02,240 --> 00:31:06,319 That's it. You get another 100-long 822 00:31:04,000 --> 00:31:08,880 vector. So that is what's called a 823 00:31:06,319 --> 00:31:12,079 residual connection. Okay. And as it 824 00:31:08,880 --> 00:31:14,480 turns out, residual connections 825 00:31:12,079 --> 00:31:16,960 improve the gradient flow during back 826 00:31:14,480 --> 00:31:18,960 propagation dramatically, and that's why 827 00:31:16,960 --> 00:31:21,440 they are very heavily used. And in fact, 828 00:31:18,960 --> 00:31:24,319 ResNet, which we looked at for computer 829 00:31:21,440 --> 00:31:26,399 vision, stands for residual net, 830 00:31:24,319 --> 00:31:29,200 because it was the first network to 831 00:31:26,398 --> 00:31:30,719 actually figure this out. This 832 00:31:29,200 --> 00:31:32,399 is not just a transformer thing, by 833 00:31:30,720 --> 00:31:35,278 the way. It's widely used in, you know, 834 00:31:32,398 --> 00:31:36,719 lots of new architectures. The notion of 835 00:31:35,278 --> 00:31:39,759 a residual connection: that's what it 836 00:31:36,720 --> 00:31:42,720 means. Okay, so we do a residual 837 00:31:39,759 --> 00:31:44,000 connection, and then we come to the final 838 00:31:42,720 --> 00:31:45,600 tweak, which is called layer 839 00:31:44,000 --> 00:31:47,440 normalization. 840 00:31:45,599 --> 00:31:48,879 So once we add the residual connection, 841 00:31:47,440 --> 00:31:51,120 we are going to do something else here 842 00:31:48,880 --> 00:31:54,080 to these vectors before they continue 843 00:31:51,119 --> 00:31:57,759 flowing.
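The residual connection itself is just element-by-element addition of two same-length vectors. A toy sketch, with a deliberately useless sublayer to show the original input surviving:

```python
import numpy as np

def sublayer_with_residual(x, sublayer):
    # the transformed output has the same length as the input,
    # so we can just add them element by element
    return x + sublayer(x)

# a deliberately "bad" transformation that throws all the information away
bad_transform = lambda x: np.zeros_like(x)

x = np.array([1.0, 2.0, 3.0])
out = sublayer_with_residual(x, bad_transform)
print(out)  # [1. 2. 3.] -- the original input survives even a useless sublayer
```

This is the insurance policy described above: later layers always still see the original input, no matter what the sublayer did to it.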
And what layer normalization does 844 00:31:54,079 --> 00:31:59,599 is, it basically says: 845 00:31:57,759 --> 00:32:00,798 you will recall from the very 846 00:31:59,599 --> 00:32:02,639 beginning of the semester I've been 847 00:32:00,798 --> 00:32:04,480 saying that whatever comes into a neural 848 00:32:02,640 --> 00:32:05,840 network, the inputs, let's just really 849 00:32:04,480 --> 00:32:07,919 make sure that they are all in some sort 850 00:32:05,839 --> 00:32:10,558 of a narrow, well-defined range; they 851 00:32:07,919 --> 00:32:12,960 can't be in a big range, right? So for 852 00:32:10,558 --> 00:32:15,038 pictures, for images, we divided every 853 00:32:12,960 --> 00:32:18,480 number by 255 so that every little pixel 854 00:32:15,038 --> 00:32:20,158 value is between zero and one. Okay? For 855 00:32:18,480 --> 00:32:22,720 continuous things, like the heart disease 856 00:32:20,159 --> 00:32:24,399 example, we standardized by calculating 857 00:32:22,720 --> 00:32:26,000 the mean and the standard deviation and 858 00:32:24,398 --> 00:32:27,278 then subtracting the mean and dividing 859 00:32:26,000 --> 00:32:28,880 by the standard deviation. So when you 860 00:32:27,278 --> 00:32:32,480 do that, all the numbers are going to 861 00:32:28,880 --> 00:32:35,039 roughly be in the -1 to +1 range. So 862 00:32:32,480 --> 00:32:36,960 in neural networks, for backprop to 863 00:32:35,038 --> 00:32:39,839 work really well, you have to make sure 864 00:32:36,960 --> 00:32:41,200 that no numbers get too big, that all the 865 00:32:39,839 --> 00:32:43,199 numbers are always in some sort of a 866 00:32:41,200 --> 00:32:45,519 narrow range. So what layer 867 00:32:43,200 --> 00:32:48,240 normalization does is to say: you know 868 00:32:45,519 --> 00:32:49,519 what, whatever is coming out here, I want 869 00:32:48,240 --> 00:32:51,519 to make sure none of these numbers are 870 00:32:49,519 --> 00:32:53,679 too big.
I want to make sure they're all 871 00:32:51,519 --> 00:32:55,599 well behaved, in a small range, because if 872 00:32:53,679 --> 00:32:59,600 I don't do that, backprop is not going 873 00:32:55,599 --> 00:33:01,759 to work very well. And so... 874 00:32:59,599 --> 00:33:04,480 >> Is this what we do to ensure we don't 875 00:33:01,759 --> 00:33:06,558 have the problem of vanishing gradients, right? 876 00:33:04,480 --> 00:33:07,839 >> So, um, technically there 877 00:33:06,558 --> 00:33:09,038 could be two problems: there's the 878 00:33:07,839 --> 00:33:10,959 exploding gradient and the vanishing 879 00:33:09,038 --> 00:33:12,480 gradient. Both are bad, and this is a way to 880 00:33:10,960 --> 00:33:15,200 address it. So you will find a whole 881 00:33:12,480 --> 00:33:17,038 bunch of something-normalization techniques: 882 00:33:15,200 --> 00:33:19,120 layer normalization, batch normalization, 883 00:33:17,038 --> 00:33:21,200 and so on and so forth. All these are 884 00:33:19,119 --> 00:33:22,798 methods to make sure that these numbers stay 885 00:33:21,200 --> 00:33:26,600 in a small range so it doesn't cause 886 00:33:22,798 --> 00:33:26,599 gradient issues later. 887 00:33:27,038 --> 00:33:32,879 All right. So in particular, 888 00:33:30,159 --> 00:33:35,200 what we do, or what happens inside 889 00:33:32,880 --> 00:33:36,480 this layer normalization, is we 890 00:33:35,200 --> 00:33:37,440 just calculate the mean and standard 891 00:33:36,480 --> 00:33:39,759 deviation of every one of these 892 00:33:37,440 --> 00:33:41,440 embeddings. Okay? Right? If you have, 893 00:33:39,759 --> 00:33:42,480 let's say, six embeddings here, we'll 894 00:33:41,440 --> 00:33:43,840 have six means and six standard 895 00:33:42,480 --> 00:33:46,000 deviations, right? For each one, across 896 00:33:43,839 --> 00:33:48,319 the rows, and then we standardize it, 897 00:33:46,000 --> 00:33:49,599 meaning subtract the mean, divide by the 898 00:33:48,319 --> 00:33:51,599 standard deviation.
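The per-embedding standardization just described, plus the learnable rescale-and-shift pair (usually called gamma and beta), can be sketched as follows. This is an illustrative sketch with scalar gamma and beta; in real implementations they are learned vectors:

```python
import numpy as np

def layer_norm(X, gamma=1.0, beta=0.0, eps=1e-5):
    # one mean and one standard deviation per embedding (per row)
    mu = X.mean(axis=-1, keepdims=True)
    sigma = X.std(axis=-1, keepdims=True)
    X_hat = (X - mu) / (sigma + eps)   # standardize: a nice, small, narrow range
    return gamma * X_hat + beta        # learnable rescale and shift

# two embeddings -> two means, two standard deviations
X = np.array([[1.0, 2.0, 3.0, 4.0],
              [10.0, 20.0, 30.0, 40.0]])
out = layer_norm(X)
print(np.round(out.mean(axis=-1), 6))  # each row's mean is now roughly 0
```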
And when you do 899 00:33:49,599 --> 00:33:54,079 that, all these things are going to be 900 00:33:51,599 --> 00:33:55,678 nice and small. And then we do this a 901 00:33:54,079 --> 00:33:58,879 little other thing where we we have 902 00:33:55,679 --> 00:34:01,120 introduced two new parameters to rescale 903 00:33:58,880 --> 00:34:03,840 it and move it around a little bit just 904 00:34:01,119 --> 00:34:06,079 because adding more weights always helps 905 00:34:03,839 --> 00:34:07,918 make these things better. So we add them 906 00:34:06,079 --> 00:34:09,358 and this gets slightly complicated 907 00:34:07,919 --> 00:34:10,480 because of the way the dimensions work. 908 00:34:09,358 --> 00:34:13,598 So I'm not going to spend much time on 909 00:34:10,480 --> 00:34:15,358 it. Uh and then what comes out the other 910 00:34:13,599 --> 00:34:16,960 end is a very well- behaved set of 911 00:34:15,358 --> 00:34:18,960 numbers in a nice and small and narrow 912 00:34:16,960 --> 00:34:20,480 range. 913 00:34:18,960 --> 00:34:23,280 Okay, so this is called layer 914 00:34:20,480 --> 00:34:25,760 normalization. Um, you can see this link 915 00:34:23,280 --> 00:34:28,480 to understand it a bit better. Um, and 916 00:34:25,760 --> 00:34:30,639 we do that as well. So to put it all 917 00:34:28,480 --> 00:34:32,639 together, 918 00:34:30,639 --> 00:34:34,159 so this is a transformer encoder where 919 00:34:32,639 --> 00:34:36,559 we have this multi head attention layer 920 00:34:34,159 --> 00:34:39,039 where each attention head in the inside 921 00:34:36,559 --> 00:34:41,039 of it is tunable with those a matrices 922 00:34:39,039 --> 00:34:43,039 and then we have a residual connection. 923 00:34:41,039 --> 00:34:45,119 We do that and then we do layer norm and 924 00:34:43,039 --> 00:34:46,800 then we do the same thing in the next 925 00:34:45,119 --> 00:34:50,399 feed forward layer as well. 
And then, boom, out pops the output.
>> By that definition, in the multi-head attention layer, when I'm doing tuning and everything, theoretically it can even pick up the biases or the hate-speech aspects that come in, right? So the model can account for the fact that something is biased or something is not?
>> The thing is, it's not so much that the model is accounting for it; it's capturing whatever patterns happen to be inherent in the data. Now, what you do with that capture is up to you; it depends on the actual problem you're trying to solve. In particular, it is going to capture all the bad stuff too, because if your training data has a lot of biased, toxic, or dangerous things in it, the model doesn't have a sense of values as to what's good or bad. It's just going to pick it up.
>> Yes.
>> On that, then: how do you actually keep it from learning those, or how do you mitigate
the effect of those?
That's a whole course unto itself, but I'm happy to give you pointers offline.
All right, so this is what we have. And remember what I said: this is just a single transformer block, and since what comes in and what goes out have the same dimensions, we can just stack them one after the other, right? It's very stackable; you can stack it vertically as much as you want. And as I mentioned, I think GPT-3 has 96 of these things stacked one on top of the other. So that brings us to it: that is the transformer encoder, and
this diagram maps exactly to that.
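As a rough sketch of what one such block computes, here is a single-head, toy-dimension version in NumPy. The names, the single head, and the simplified residual-plus-normalization wiring are all illustrative assumptions, not a faithful reproduction of any production implementation:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def layer_norm(x, eps=1e-5):
    m, s = x.mean(-1, keepdims=True), x.std(-1, keepdims=True)
    return (x - m) / (s + eps)

def encoder_block(x, A_q, A_k, A_v, W1, W2):
    # single-head self-attention over the whole sentence
    Q, K, V = x @ A_q, x @ A_k, x @ A_v
    att = softmax(Q @ K.T / np.sqrt(K.shape[-1])) @ V
    x = layer_norm(x + att)            # residual connection + layer norm
    ff = np.maximum(0, x @ W1) @ W2    # feed-forward layer with ReLU
    return layer_norm(x + ff)          # residual connection + layer norm again

rng = np.random.default_rng(0)
d = 8                                  # embedding dimension
x = rng.normal(size=(5, d))            # 5 positional input embeddings
out = encoder_block(x,
                    *(rng.normal(size=(d, d)) * 0.1 for _ in range(3)),
                    rng.normal(size=(d, 4 * d)) * 0.1,
                    rng.normal(size=(4 * d, d)) * 0.1)
```

Because the output has the same shape as the input, `encoder_block` can be applied to its own output, which is exactly what makes the blocks stackable.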
So basically the input embeddings come in, you add positional embeddings, and then you send them through these several attention heads, and their outputs all get combined. Then it comes out of the attention block; the "Add & Norm" here means a residual connection, because you're adding the input, which is why you have this arrow going from the input and being added there, and then you normalize it, send it along, do it again, and out comes the output.
So, all right. Now, just to be very clear on what is being optimized during backpropagation in this complex flow: clearly the embeddings that you started out with, both the standalone embeddings and the positional embeddings, are going to get optimized; those are just weights. Clearly everything inside the transformer encoder block is going to get optimized. And what are they? Well, they are the A_Q, A_K, and A_V matrices for each attention head. Layer norm has
parameters as well.
The next thing, the little feed-forward layer, has weights as well. All these things are going to get optimized. Then it goes through this ReLU layer, which again sits on a bunch of weights that get optimized, and then the final softmax has a bunch of weights that get optimized.
All these things are going to get optimized by backprop.
Okay. So in that sense, just step back for a second and look at the whole thing: it is just a mathematical model with a lot of parameters, and we're just going to use gradient descent, or stochastic gradient descent, to optimize it. That's it.
Yeah.
>> For those A matrices, when we train the model, are we calculating weights for each cell of every possible matrix based on the number of inputs, like every possible dimension up to the max number
of inputs?
Actually, the weights themselves don't depend on how long your input sentence is, because remember what we're doing: for each sentence that comes in, let's say the sentence has three words, there are three embeddings for that sentence, and each of those embeddings gets multiplied by, say, A_K. So A_K only needs to know how long each embedding is; it doesn't need to know how many words there are.
And I'm glad you raised that question, Ben, because that's what makes a transformer's number of weights independent of the number of words in your sentence.
It only depends on the vocabulary that you're going to work with, because the vocabulary determines how many embeddings you need. The length only matters for the positional embeddings, because if you have a thousand-word sentence, you need a thousand rows in your positional embedding matrix. But beyond
that, it doesn't care.
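You can see that independence directly from the shapes: a matrix like A_K maps one embedding of dimension d to another, so it holds d × d weights no matter how many words arrive. A toy illustration (d = 8 is an arbitrary choice):

```python
import numpy as np

d = 8                              # embedding dimension, fixed by the model
A_k = np.random.randn(d, d)        # one weight matrix: d*d parameters, full stop

short = np.random.randn(3, d)      # a 3-word sentence: 3 embeddings
long = np.random.randn(1000, d)    # a 1000-word sentence

K_short = short @ A_k              # works for 3 words ...
K_long = long @ A_k                # ... and for 1000, with the exact same A_k
```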
And that's why, for example, Google's Gemini 1.5 Pro can accommodate basically a million-token context window. It's still very compute heavy, but it does not change the number of parameters.
>> Conceptually, which weights are optimized first? Are they optimized in sequential order, or are all the weights optimized at the very same time?
>> Simultaneously. Because if you think of backpropagation, ultimately you have a loss function, and you calculate the gradient of that loss function. So if you have, say, a billion parameters, that gradient is basically a billion-long vector, and we're going to take the gradient and do w_new = w_old − α × gradient. So all the w's are going to update together.
Now, the way it actually works computationally is that, because of backpropagation, it starts at the end and slowly flows backwards; but when it's done,
everything will be updated.
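That update rule, written out, is a single vector operation over every parameter at once. A toy sketch with a made-up loss (the billion parameters stand in as a three-entry vector):

```python
import numpy as np

# pretend the model's billion parameters are this little vector
w = np.array([3.0, -2.0, 0.5])
alpha = 0.1                        # learning rate

def grad(w):
    # gradient of the toy loss L(w) = 0.5 * ||w||^2, which is just w
    return w

# w_new = w_old - alpha * gradient: every entry updates in one step
w_new = w - alpha * grad(w)
```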
Yeah.
>> If we take two attention heads, each with its A_K, A_Q, and A_V matrices: why would the weights of the three matrices on this side and on that side end up different? After all, what you're inputting on both sides, and the output, are the same, so the learning process should ideally be the same, unlike a CNN, where we put in filters that were different. So what's different here?
>> Because the initialization is different.
>> What do we mean?
>> What I mean is, if you have two heads, each head has three matrices. The starting values of those six matrices are different.
>> The starting values of A_K, A_Q, and A_V are different for both heads?
>> Right. Much like for all weights, the starting values are typically chosen randomly. If they were all the same, you're right, it wouldn't make a difference: they would all change
the same way. Yeah.
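A quick sketch of that point: give each head's matrices their own random starting values, and the two heads compute different things from the very same input from step one, so their gradients, and hence their learned weights, diverge. Toy dimensions, just for illustration:

```python
import numpy as np

rng = np.random.default_rng(42)
d = 4
x = rng.normal(size=(3, d))            # the same input goes to both heads

# two heads, each with its own randomly initialized A_K
A_k_head1 = rng.normal(size=(d, d))
A_k_head2 = rng.normal(size=(d, d))    # a different random draw

K1 = x @ A_k_head1
K2 = x @ A_k_head2                     # different keys from the same input
```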
>> Is the input to the transformer the sentence itself, or the array of embeddings of each word?
>> The transformer itself is expecting embeddings in. What basically happens is we get some sentence, we run it through a tokenizer, which converts it into a bunch of tokens, which are just integers, and then it goes through the embedding layer, which maps the integers to these embeddings, and then you feed those to the transformer. But when you do backpropagation, it comes all the way back to the starting embedding layer and updates those weights too.
>> Okay, so they can be trainable. So the weights at the beginning must be input here, but they can train.
>> They're trainable. Exactly. Exactly.
>> Yeah. Are the attention heads solely parallel, or can you have, like, a stack of attention heads?
>> Typically they are parallelized, because you can always stack the block itself to get more and more power.
All right.
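The front-end pipeline just described, text to tokens to embeddings, can be sketched like this. The word-level vocabulary and the random, trainable embedding table here are toy assumptions; real systems use subword tokenizers, but the flow is the same:

```python
import numpy as np

vocab = {"<pad>": 0, "the": 1, "train": 2, "left": 3, "late": 4}
d = 8                                                # embedding dimension
rng = np.random.default_rng(0)
embedding_table = rng.normal(size=(len(vocab), d))   # trainable weights

def tokenize(sentence):
    """Map each word to its integer token id."""
    return [vocab[w] for w in sentence.lower().split()]

def embed(token_ids):
    """Look up each token id's row in the embedding table."""
    return embedding_table[token_ids]

tokens = tokenize("the train left late")   # integers, one per word
x = embed(tokens)                          # what the transformer actually sees
```

During training, backprop would update rows of `embedding_table` exactly like any other weights.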
So now, to apply the transformer: the common use cases are, first, that you have a whole sentence coming in and you just want to classify it, the canonical example being movie sentiment classification, boom, positive or negative. Another common one is labeling, where every word gets a multiclass label, and that's basically what we saw with our slot-filling problem. And then there is another thing called sequence generation, where you give it a sequence and you want it to continue the sequence, generate more stuff, i.e. large language models and all that good stuff. The labeling case we already know how to do, because we literally built a colab doing it with the transformer stack. Now the question is: how can we do basic classification with
these things?
So again, when you send a sentence in, after all that machinery runs, and when I say "encoder" here, you may have one block or you may have many blocks, I don't care, at the end of the day you send something in and you get a bunch of contextual embeddings out.
So at this point we need to take these contextual embeddings and somehow make them work for classification, for classifying something into yes or no, positive or negative. So it would be nice if we could take all these embeddings and essentially summarize them into a single embedding, a single vector, because if you have a single vector, then we can run it through maybe a ReLU and then a sigmoid, and boom, we can do a binary classification problem, super easy. So this begs the question: how are we going to go from all the many blue things to one green thing?
Okay, now, of course, what we can do is simply average them: take each of the embeddings and just average them, element by element, and you'll get a nice green thing. Okay. Any shortcomings from doing that?
>> You would lose the ordering of the words.
>> Well, in some sense the positional encoding you have in the input does carry this notion of position, right? So you're not necessarily losing the order, but you are averaging all this information into one thing, and averaging is going to lose some richness.
Okay.
>> I think it's going to be skewed toward the one that has the biggest numbers, right? So something is influencing your average.
>> Yeah, the biggest ones are going to dominate. But hopefully we won't have too much of that, because all the layer norm business earlier has hopefully made sure the numbers are all in a reasonably small and well-behaved range. But the point really is that you're going to lose richness in the information, because you're just
mushing it down.
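The averaging approach just discussed, often called mean pooling, is a one-liner. Toy numbers, just to show the shape change:

```python
import numpy as np

# six contextual embeddings of dimension 4 coming out of the encoder
contextual = np.arange(24, dtype=float).reshape(6, 4)

# element-by-element average across the words -> one summary vector
sentence_vec = contextual.mean(axis=0)
```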
So there's a much better and more elegant way to do this, which is: for every sentence, when you train, you add an artificial token called the class token. Okay, literally it's an artificial token, designated CLS in the literature, and this token gets trained along with everything else.
Okay. And so once you finish training, that token has its own embedding too. And because it has been trained with everything else, and remember, this is a contextual embedding, which means it's very much aware of all the other words in the sentence, in some sense the CLS token's contextual embedding captures everything that's going on in that sentence.
And so what we do is, once we're done training, we just grab this one embedding alone and send it through a ReLU
and a sigmoid, and boom, you're done.
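A sketch of the CLS trick: prepend one extra learned embedding, run everything through the encoder together, then classify from that one output row alone. The `encoder` here is a stand-in (any stack of transformer blocks would do), and the little ReLU-plus-sigmoid head uses made-up sizes:

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8

def encoder(x):
    # stand-in for the transformer encoder stack: any function mapping
    # n embeddings to n contextual embeddings of the same shape
    return np.tanh(x @ rng.normal(size=(d, d)) * 0.5)

cls_embedding = rng.normal(size=(1, d))   # trainable, one per model
words = rng.normal(size=(5, d))           # a 5-word sentence's embeddings

x = np.vstack([cls_embedding, words])     # prepend CLS to the sentence
contextual = encoder(x)
cls_out = contextual[0]                   # grab the CLS row only

# tiny classification head: ReLU hidden layer, then a sigmoid
W_h, W_o = rng.normal(size=(d, 4)), rng.normal(size=4)
p = 1 / (1 + np.exp(-(np.maximum(0, cls_out @ W_h) @ W_o)))
```

The key point is that `cls_out` was computed with attention over all the words, so that single row summarizes the whole sentence.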
So this is a very clever trick: instead of averaging everything at the end, let's just have one thing that stands for the whole sentence and learn it along with everything else. A meta-principle in deep learning is that whenever you think you're making an ad hoc decision about something, like averaging a bunch of stuff, you should always stop and ask: is there a better way to do it, where it doesn't have to be ad hoc, where the right way is learnable from the data directly using backpropagation? There was a hand. Yeah.
>> Is there a reason that you added the CLS at the start? Why not add it at the end?
>> You can do it at the end.
>> Is there any difference?
>> It's a good question. The only thing to remember is that different sentences are going to be of different lengths, right? There might be short sentences, there might be long sentences. In particular, the short sentences are going to
get padded, right?
Remember, I talked about padding to make everything fit one length. Internally, the transformer will ignore all the padded tokens, because it's just padding; it doesn't really matter for anything. So if you put the CLS at the very end, you need much more administrative bookkeeping to skip past the padding and pick out that last real token. It's just much easier to put it at the beginning. That's the reason. Yeah.
>> What would be a practical application of this? Would it be something like sentiment analysis, positive or negative?
>> Yeah. Basically, any kind of text comes in and you want to solve some labeling problem, like a classification problem. The easiest example I could think of was sentiment. But you can imagine, for example, an email comes into a call center operation, and you want to take the email and automatically figure out which department to send it to.
Okay.
So now, if the input data for a task is natural-language text, we don't have to restrict ourselves to only the training data we have, right? Wouldn't it be great to learn from all the text that's out there? For example, to go back to that call center scenario I just mentioned: let's say the email is coming in English. For the ability to take that English email and route it to one of ten departments, you shouldn't have to learn English just for your call center application. You should learn English generally and use it for other things too, right? So why can't we just learn from all the text that's out there? And that brings us to something called self-supervised learning. And the idea of self-supervised learning is this.
If you recall the transfer learning example from lecture four, where we had ResNet: we took ResNet, chopped off the final layer to make it sort of headless, attached the output of the headless ResNet to a little hidden layer and an output, and did the handbags-and-shoes problem. And you'll recall that we were able to build a very good classifier for handbags and shoes with just about a hundred examples. So the question is: why was this so effective?
And it turns out the reason any of this stuff actually works is that neural networks learn representations automatically when you train them. What I mean by that is: imagine a network; you feed in a bunch of stuff; it goes through all the layers and comes out. You can think of each layer as transforming the raw input into some different, alternate representation of the input. Okay? And these are called representations;
that's actually a technical term.
And so, from this perspective, when you train a neural network, a deep network with lots of layers, what you're really learning is how to represent the input in many different ways (each of these arrows is a different way of representing things), plus a final regression model, either a linear regression model or a logistic regression model. Fundamentally, that's what's going on, because the final layers tend to be sigmoid, softmax, or just linear, right? So if you look at the final layer alone, whatever comes in is just going through essentially a linear regression model or a logistic regression model, that's it. So fundamentally you're learning representations and a final little model. Okay. But the reason all these things work so much better than plain logistic regression is that those representations have learned all kinds
of useful things about the input data.
They have, in effect, automatically feature-engineered for you.
So from this perspective, you can imagine that each layer here is like an encoder: it encodes the input, right? The first layer encodes it, the first two layers encode something, the first three layers encode something, and so on and so forth. So a deep network contains many encoders. And so the question is: what do these representations actually embody? What do they capture? Is it specific knowledge about the particular problem you trained the network on, or is it general knowledge about the input data? Because if it is general knowledge about the input, we can use it to solve other, unrelated problems. So: specific knowledge or general knowledge? And it turns out they actually capture a lot of general knowledge about the input, and that's why you can get reuse out of them; you can reuse them for other, unrelated things, because they have
captured general stuff.
So if you look 1381 00:50:36,400 --> 00:50:40,160 at this, I think I've shown you before, 1382 00:50:38,239 --> 00:50:41,759 right? If you look at a network 1383 00:50:40,159 --> 00:50:43,358 that classifies everyday objects into a 1384 00:50:41,760 --> 00:50:44,720 bunch of categories, it can learn all 1385 00:50:43,358 --> 00:50:46,799 these little patterns in the beginning 1386 00:50:44,719 --> 00:50:48,558 and later on, and so on and so forth. And 1387 00:50:46,800 --> 00:50:50,480 this is a face detection network. It has 1388 00:50:48,559 --> 00:50:52,640 learned how to, you know, 1389 00:50:50,480 --> 00:50:55,280 identify little circles and edges and 1390 00:50:52,639 --> 00:50:56,639 nose-like shapes and finally faces. So 1391 00:50:55,280 --> 00:50:57,760 all these things are examples of 1392 00:50:56,639 --> 00:51:00,960 representations learning interesting 1393 00:50:57,760 --> 00:51:02,480 things about the input. Okay. So since 1394 00:51:00,960 --> 00:51:04,240 these representations are capturing 1395 00:51:02,480 --> 00:51:06,960 intrinsic aspects of the data, you can 1396 00:51:04,239 --> 00:51:08,719 use them for other things, right? You can 1397 00:51:06,960 --> 00:51:10,559 take a face detection neural network and 1398 00:51:08,719 --> 00:51:12,318 reuse it for emotion detection, 1399 00:51:10,559 --> 00:51:14,640 for instance. 1400 00:51:12,318 --> 00:51:17,358 Uh, so the question is, if you can somehow 1401 00:51:14,639 --> 00:51:19,358 get an encoder that generates good 1402 00:51:17,358 --> 00:51:20,799 representations for your input data, we 1403 00:51:19,358 --> 00:51:22,558 can simply build a regression model with 1404 00:51:20,800 --> 00:51:24,079 those as input and labels as output and 1405 00:51:22,559 --> 00:51:27,359 be done. And this is exactly what we did 1406 00:51:24,079 --> 00:51:28,960 with ResNet for handbags and shoes.
We 1407 00:51:27,358 --> 00:51:30,799 found a thing that had already been 1408 00:51:28,960 --> 00:51:33,679 trained on similar everyday objects, 1409 00:51:30,800 --> 00:51:35,200 everyday images. And the key insight 1410 00:51:33,679 --> 00:51:37,279 here is that since we don't have to 1411 00:51:35,199 --> 00:51:40,078 spend precious data on learning these 1412 00:51:37,280 --> 00:51:42,160 good representations, 1413 00:51:40,079 --> 00:51:44,880 we won't need as much labeled data in the 1414 00:51:42,159 --> 00:51:46,318 first place, because the pre-training 1415 00:51:44,880 --> 00:51:48,318 used a lot of data and you're sort of 1416 00:51:46,318 --> 00:51:50,239 piggybacking on that data. So in some 1417 00:51:48,318 --> 00:51:51,599 sense, your training data is everything 1418 00:51:50,239 --> 00:51:55,358 that the pre-trained model was trained 1419 00:51:51,599 --> 00:51:57,119 on plus your little 200 examples. 1420 00:51:55,358 --> 00:51:58,558 Um, okay. So this is what we did. We 1421 00:51:57,119 --> 00:52:00,160 used headless ResNet as an encoder 1422 00:51:58,559 --> 00:52:02,480 that can take raw input and transform it 1423 00:52:00,159 --> 00:52:04,639 into useful representations. 1424 00:52:02,480 --> 00:52:06,318 All right. So the general 1425 00:52:04,639 --> 00:52:08,000 approach is that you find a deep neural 1426 00:52:06,318 --> 00:52:10,719 network built on similar inputs but 1427 00:52:08,000 --> 00:52:13,119 different outputs. Uh, and then you 1428 00:52:10,719 --> 00:52:15,439 basically grab maybe the penultimate 1429 00:52:13,119 --> 00:52:17,760 representation or the one before that. 1430 00:52:15,440 --> 00:52:21,119 Then you chop off the head. You attach 1431 00:52:17,760 --> 00:52:23,119 your own output head. Then you train 1432 00:52:21,119 --> 00:52:25,039 just the final layer, or train the 1433 00:52:23,119 --> 00:52:26,079 whole thing if you want. Right?
This is 1434 00:52:25,039 --> 00:52:27,838 like the playbook we followed for 1435 00:52:26,079 --> 00:52:30,720 ResNet. The same thing works for all 1436 00:52:27,838 --> 00:52:32,318 kinds of other data types as well. So 1437 00:52:30,719 --> 00:52:34,000 now, to build such a model we need 1438 00:52:32,318 --> 00:52:35,599 labeled data, right? We were lucky 1439 00:52:34,000 --> 00:52:37,599 because ResNet was actually trained on 1440 00:52:35,599 --> 00:52:39,119 ImageNet data, which is like a million 1441 00:52:37,599 --> 00:52:40,960 images, each labeled into a 1442 00:52:39,119 --> 00:52:44,880 thousand categories, which is very 1443 00:52:40,960 --> 00:52:46,639 convenient for us, right? But what if 1444 00:52:44,880 --> 00:52:49,760 you want to build a generally useful 1445 00:52:46,639 --> 00:52:51,279 model for text data? 1446 00:52:49,760 --> 00:52:52,559 Clearly we need to collect a lot of text 1447 00:52:51,280 --> 00:52:54,160 data. But that's no problem because the 1448 00:52:52,559 --> 00:52:55,680 internet is full of text data, right? We 1449 00:52:54,159 --> 00:52:57,519 can easily scrape the internet. We can 1450 00:52:55,679 --> 00:52:59,759 just download Wikipedia. So that's not a 1451 00:52:57,519 --> 00:53:02,559 problem. The problem is something else, 1452 00:52:59,760 --> 00:53:05,520 which is: how do we define a 1453 00:53:02,559 --> 00:53:07,119 label for a piece of text? So for an 1454 00:53:05,519 --> 00:53:09,199 input sentence, what should the output 1455 00:53:07,119 --> 00:53:10,480 label be? That's the key question. 1456 00:53:09,199 --> 00:53:11,759 Because if you can answer this question, 1457 00:53:10,480 --> 00:53:14,318 you can just train all these 1458 00:53:11,760 --> 00:53:17,520 things on all kinds of text data, right? 1459 00:53:14,318 --> 00:53:18,800 So a beautiful idea for doing 1460 00:53:17,519 --> 00:53:20,880 this is called self-supervised learning.
1461 00:53:18,800 --> 00:53:23,359 And the key idea is that you take your 1462 00:53:20,880 --> 00:53:26,079 input, whatever the input is, you take a 1463 00:53:23,358 --> 00:53:28,719 small part of the input and just remove 1464 00:53:26,079 --> 00:53:31,680 it, and then ask your network to fill in 1465 00:53:28,719 --> 00:53:33,919 the blanks from everything else. 1466 00:53:31,679 --> 00:53:35,118 Okay, so this is called masking, and it's 1467 00:53:33,920 --> 00:53:36,559 just one of many techniques in 1468 00:53:35,119 --> 00:53:39,119 self-supervised learning, but this is 1469 00:53:36,559 --> 00:53:41,680 very commonly used. So this is the original 1470 00:53:39,119 --> 00:53:43,599 input, right? And then you take it and 1471 00:53:41,679 --> 00:53:45,679 you just take this thing in 1472 00:53:43,599 --> 00:53:48,880 the middle here, randomly, and 1473 00:53:45,679 --> 00:53:51,199 zero it out, or mask it. And so this 1474 00:53:48,880 --> 00:53:53,119 incomplete input is now your new input, 1475 00:53:51,199 --> 00:53:56,719 and the thing that you took out becomes 1476 00:53:53,119 --> 00:53:58,240 your fake label. 1477 00:53:56,719 --> 00:54:00,399 So you can almost imagine, right, if 1478 00:53:58,239 --> 00:54:02,318 you're baking donuts, you 1479 00:54:00,400 --> 00:54:04,720 make a donut and then you punch a 1480 00:54:02,318 --> 00:54:07,199 hole in the middle of the donut. The 1481 00:54:04,719 --> 00:54:11,598 donut with the hole is your input; the 1482 00:54:07,199 --> 00:54:13,039 munchkin is the label. 1483 00:54:11,599 --> 00:54:15,200 Am I making everybody hungry at this 1484 00:54:13,039 --> 00:54:17,838 point? 1485 00:54:15,199 --> 00:54:19,519 So, once you do that, no problem.
You 1486 00:54:17,838 --> 00:54:23,799 have inputs, you have 1487 00:54:19,519 --> 00:54:23,800 labels, you just train a neural network 1488 00:54:23,838 --> 00:54:28,558 to essentially predict those, to 1489 00:54:25,679 --> 00:54:30,879 basically fill in the blanks. 1490 00:54:28,559 --> 00:54:32,559 And so, for example, if you take a 1491 00:54:30,880 --> 00:54:34,640 sentence like "the Sloan School's 1492 00:54:32,559 --> 00:54:36,559 mission," you can just go in there and 1493 00:54:34,639 --> 00:54:39,199 just knock out randomly a bunch of 1494 00:54:36,559 --> 00:54:40,319 words like this. And the ones I'm 1495 00:54:39,199 --> 00:54:42,960 knocking out, I'm just putting the word 1496 00:54:40,318 --> 00:54:45,119 MASK in, just to show what I'm doing. 1497 00:54:42,960 --> 00:54:46,720 And then, given this 1498 00:54:45,119 --> 00:54:50,240 sentence, it will try to fill in the 1499 00:54:46,719 --> 00:54:51,759 blanks with actual words. 1500 00:54:50,239 --> 00:54:53,439 Okay, 1501 00:54:51,760 --> 00:54:54,400 so now for the amazing part. In the 1502 00:54:53,440 --> 00:54:57,358 process of learning to fill in the 1503 00:54:54,400 --> 00:54:58,960 blanks, uh, the network learns a really 1504 00:54:57,358 --> 00:55:01,199 good representation of the kind of input 1505 00:54:58,960 --> 00:55:02,880 data it's seeing. And it kind of makes 1506 00:55:01,199 --> 00:55:04,879 sense, right? Because if I give you a 1507 00:55:02,880 --> 00:55:06,800 sentence with a few missing blanks and 1508 00:55:04,880 --> 00:55:08,720 you're able to very successfully fill in 1509 00:55:06,800 --> 00:55:10,079 the blanks, you have learned a whole 1510 00:55:08,719 --> 00:55:12,558 bunch of stuff about the world to be 1511 00:55:10,079 --> 00:55:14,318 able to do that, right? If I say "the 1512 00:55:12,559 --> 00:55:16,800 capital of France is blank" and you're 1513 00:55:14,318 --> 00:55:18,880 like "Paris," okay, how did you know that?
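[Editor's note: the donut trick can be sketched in plain Python. The 15% masking rate and the `[MASK]` placeholder follow the BERT-style convention; the sentence and the fixed seed are just for illustration.]

```python
import random

def make_masked_example(tokens, mask_rate=0.15, seed=1):
    """Turn a token list into a (masked_input, labels) training pair.

    labels[i] holds the original token where we masked (the "munchkin"),
    or None where we didn't, so the loss is computed only at masked spots.
    """
    rng = random.Random(seed)
    masked, labels = [], []
    for tok in tokens:
        if rng.random() < mask_rate:
            masked.append("[MASK]")   # the donut hole
            labels.append(tok)        # the free label we get for it
        else:
            masked.append(tok)
            labels.append(None)
    return masked, labels

sentence = "the sloan school 's mission is to develop principled leaders".split()
masked, labels = make_masked_example(sentence)
print(masked)
print(labels)
```

Training data for free: every sentence scraped from the web yields many such (masked input, fake label) pairs, with no human labeling.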
1514 00:55:16,800 --> 00:55:20,559 It's sort of like that. By learning to 1515 00:55:18,880 --> 00:55:22,559 fill in the blanks, you really have to 1516 00:55:20,559 --> 00:55:24,079 learn how all these things work, all 1517 00:55:22,559 --> 00:55:27,760 the connections between various 1518 00:55:24,079 --> 00:55:29,839 words, and so on and so forth. And so 1519 00:55:27,760 --> 00:55:32,000 what you can do is, once we build such a 1520 00:55:29,838 --> 00:55:34,159 model, we can just extract an encoder 1521 00:55:32,000 --> 00:55:36,079 from it, right? And then we'll fine-tune 1522 00:55:34,159 --> 00:55:38,239 it like we do with regular transfer 1523 00:55:36,079 --> 00:55:41,359 learning. But this is how you build a 1524 00:55:38,239 --> 00:55:43,598 generic pre-trained model on 1525 00:55:41,358 --> 00:55:46,159 unlabeled data. 1526 00:55:43,599 --> 00:55:48,000 And so we can use a transformer encoder 1527 00:55:46,159 --> 00:55:49,598 to build this whole thing in the middle, 1528 00:55:48,000 --> 00:55:51,599 because remember, the transformer can 1529 00:55:49,599 --> 00:55:53,280 take any sentence and give you the same 1530 00:55:51,599 --> 00:55:55,280 size sentence back, along with 1531 00:55:53,280 --> 00:55:57,280 predictions for everything. So we can 1532 00:55:55,280 --> 00:55:58,880 just have it take this thing in and ask 1533 00:55:57,280 --> 00:56:01,519 it to just predict all the missing words 1534 00:55:58,880 --> 00:56:03,119 here. 1535 00:56:01,519 --> 00:56:05,358 And 1536 00:56:03,119 --> 00:56:06,880 so, uh, to put it in other words, masked 1537 00:56:05,358 --> 00:56:09,440 self-supervised learning is just a 1538 00:56:06,880 --> 00:56:11,039 sequence labeling problem. 1539 00:56:09,440 --> 00:56:13,440 So basically this is the sequence that 1540 00:56:11,039 --> 00:56:14,639 comes in, and then you feed it to the 1541 00:56:13,440 --> 00:56:16,240 transformer and you get all these 1542 00:56:14,639 --> 00:56:18,078 embeddings.
It goes through all that 1543 00:56:16,239 --> 00:56:21,118 stuff. You really don't care about these 1544 00:56:18,079 --> 00:56:23,359 outputs. But wherever the word MASK went 1545 00:56:21,119 --> 00:56:25,358 in in the input, you basically try 1546 00:56:23,358 --> 00:56:26,798 to get it to produce the right answer, 1547 00:56:25,358 --> 00:56:28,159 which is, for example, the word 1548 00:56:26,798 --> 00:56:29,759 "mission." That is the right answer 1549 00:56:28,159 --> 00:56:31,440 here. And then 1550 00:56:29,760 --> 00:56:32,799 you take these right answers, create a 1551 00:56:31,440 --> 00:56:35,519 loss function, and do backprop, and 1552 00:56:32,798 --> 00:56:37,358 boom, you're done. 1553 00:56:35,519 --> 00:56:40,159 Inputs, right answers, and you're in 1554 00:56:37,358 --> 00:56:41,759 business. That's it. Now, if we 1555 00:56:40,159 --> 00:56:44,078 pre-train a transformer model like this 1556 00:56:41,760 --> 00:56:46,240 on massive amounts of English text, 1557 00:56:44,079 --> 00:56:48,960 let's say we did that, we get something 1558 00:56:46,239 --> 00:56:51,118 called BERT. BERT is a very famous 1559 00:56:48,960 --> 00:56:53,599 transformer model. And BERT was the 1560 00:56:51,119 --> 00:56:56,400 first model, actually, that Google used to 1561 00:56:53,599 --> 00:56:58,559 upgrade its search in 2019. 1562 00:56:56,400 --> 00:57:00,318 Like the Brazil visa example you 1563 00:56:58,559 --> 00:57:03,599 may recall from earlier lectures: that 1564 00:57:00,318 --> 00:57:06,400 uses BERT under the hood. Okay. Um, and 1565 00:57:03,599 --> 00:57:07,920 so now I just want to show you, because 1566 00:57:06,400 --> 00:57:09,680 you can actually read the BERT paper and 1567 00:57:07,920 --> 00:57:10,880 it'll actually make sense to you now 1568 00:57:09,679 --> 00:57:13,440 based on what you have learned in this 1569 00:57:10,880 --> 00:57:14,798 class.
Look at this: BERT's model 1570 00:57:13,440 --> 00:57:16,798 architecture is a multi-layer 1571 00:57:14,798 --> 00:57:18,639 bidirectional transformer encoder. Okay, 1572 00:57:16,798 --> 00:57:20,639 transformer encoder. We denote the 1573 00:57:18,639 --> 00:57:23,039 number of layers (transformer blocks) as 1574 00:57:20,639 --> 00:57:25,118 L, the hidden size as H, and the number 1575 00:57:23,039 --> 00:57:30,558 of attention heads as A. And how much is 1576 00:57:25,119 --> 00:57:34,318 that? Uh, okay, H is 768, okay, 1577 00:57:30,559 --> 00:57:36,480 so which means that the embedding size 1578 00:57:34,318 --> 00:57:38,318 is 768, 1579 00:57:36,480 --> 00:57:41,599 and the hidden feed-forward layer is 1580 00:57:38,318 --> 00:57:44,719 four times as much, so it's 3072. 1581 00:57:41,599 --> 00:57:47,760 The 3072 is the feed-forward 1582 00:57:44,719 --> 00:57:49,838 layer; the embeddings are 768. And you can 1583 00:57:47,760 --> 00:57:52,799 see there are two BERT models here: this 1584 00:57:49,838 --> 00:57:55,759 one has 12 transformer blocks, this one 1585 00:57:52,798 --> 00:57:58,159 has 24 transformer blocks. 1586 00:57:55,760 --> 00:57:59,440 Okay, so you can actually read this 1587 00:57:58,159 --> 00:58:00,879 paper. You can actually relate 1588 00:57:59,440 --> 00:58:02,720 it to exactly what we discussed in 1589 00:58:00,880 --> 00:58:04,640 class. It'll all make sense. 1590 00:58:02,719 --> 00:58:06,239 Bidirectional means that the words can 1591 00:58:04,639 --> 00:58:09,598 pay attention to every other word in the 1592 00:58:06,239 --> 00:58:10,959 sentence.
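[Editor's note: a quick sanity check on those numbers. The 4x feed-forward width is the convention inherited from the original Transformer; BERT-large's H=1024 and A=16 come from the same table in the BERT paper, and H/A gives each attention head's working dimension.]

```python
# BERT hyperparameters from the paper: L layers, hidden size H, A attention heads.
configs = {
    "BERT-base": dict(L=12, H=768, A=12),
    "BERT-large": dict(L=24, H=1024, A=16),
}

sizes = {}
for name, c in configs.items():
    sizes[name] = dict(
        ff=4 * c["H"],            # feed-forward layer is 4x the hidden size
        head_dim=c["H"] // c["A"] # each attention head works in H/A dimensions
    )
    print(name, sizes[name])
# BERT-base:  feed-forward 3072, per-head size 64
# BERT-large: feed-forward 4096, per-head size 64
```

So for BERT-base the feed-forward width is 3072, not 4096; 4096 is BERT-large's, since its hidden size is 1024.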
And as we will see on Monday, 1593 00:58:09,599 --> 00:58:12,400 you have another 1594 00:58:10,960 --> 00:58:14,240 transformer variant called a causal 1595 00:58:12,400 --> 00:58:15,440 transformer, in which you only pay 1596 00:58:14,239 --> 00:58:18,000 attention to the words that came before 1597 00:58:15,440 --> 00:58:21,014 you, not the ones after you. So 1598 00:58:18,000 --> 00:58:24,400 bidirectional means all words are seen. 1599 00:58:21,014 --> 00:58:26,639 Okay. So, um, what we do is, 1600 00:58:24,400 --> 00:58:27,760 remember we said that to solve sequence 1601 00:58:26,639 --> 00:58:30,719 classification you can add a little 1602 00:58:27,760 --> 00:58:32,480 token at the beginning, uh, and then, boom, 1603 00:58:30,719 --> 00:58:35,199 use it for classification. As it turns 1604 00:58:32,480 --> 00:58:36,960 out, very conveniently for us, the 1605 00:58:35,199 --> 00:58:38,639 people who built BERT, 1606 00:58:36,960 --> 00:58:41,039 when they trained BERT, they just used 1607 00:58:38,639 --> 00:58:42,318 the CLS business 1608 00:58:41,039 --> 00:58:44,719 during training, so it's actually 1609 00:58:42,318 --> 00:58:46,159 available for us out of the box. So when 1610 00:58:44,719 --> 00:58:47,439 you use BERT for sequence classification, 1611 00:58:46,159 --> 00:58:48,798 you don't even have to do any surgery on 1612 00:58:47,440 --> 00:58:51,519 it; it just gives you the class token 1613 00:58:48,798 --> 00:58:52,960 automatically, which is very convenient. 1614 00:58:51,519 --> 00:58:55,280 Uh, and you can also use it for sequence 1615 00:58:52,960 --> 00:58:57,440 labeling as well.
So for sequence 1616 00:58:55,280 --> 00:58:58,960 classification and sequence labeling, uh, 1617 00:58:57,440 --> 00:59:00,960 BERT is actually usually a really good 1618 00:58:58,960 --> 00:59:02,159 starting point, and in particular there 1619 00:59:00,960 --> 00:59:04,240 have been lots of improvements and 1620 00:59:02,159 --> 00:59:05,759 variations of BERT over the years. And if 1621 00:59:04,239 --> 00:59:07,199 you're curious about this, there's a 1622 00:59:05,760 --> 00:59:09,040 thing called the sentence-transformers 1623 00:59:07,199 --> 00:59:11,199 library, which has got a whole bunch of 1624 00:59:09,039 --> 00:59:14,400 BERT-related code and resources that you 1625 00:59:11,199 --> 00:59:18,480 can use to do things out of the box. 1626 00:59:14,400 --> 00:59:20,000 Okay. So, okay, there's a bit of a word 1627 00:59:18,480 --> 00:59:21,920 wall. 1628 00:59:20,000 --> 00:59:23,519 So, to solve any of these problems, 1629 00:59:21,920 --> 00:59:24,720 classification or labeling, where the 1630 00:59:23,519 --> 00:59:27,199 input is natural language, we can 1631 00:59:24,719 --> 00:59:28,719 obviously use a model like BERT: label a 1632 00:59:27,199 --> 00:59:30,159 few hundred examples, attach the right 1633 00:59:28,719 --> 00:59:32,318 final layers, and fine-tune it like we 1634 00:59:30,159 --> 00:59:34,879 did for ResNet. But if your 1635 00:59:32,318 --> 00:59:37,358 problem is like a standard NLP problem, 1636 00:59:34,880 --> 00:59:39,280 okay, you don't even have to do that, 1637 00:59:37,358 --> 00:59:40,719 because for these standard tasks, people 1638 00:59:39,280 --> 00:59:43,440 have already pre-trained models on those 1639 00:59:40,719 --> 00:59:44,719 standard tasks, right? And so you can do 1640 00:59:43,440 --> 00:59:47,440 all these things without any fine-tuning 1641 00:59:44,719 --> 00:59:49,199 at all, like literally out of the box. 1642 00:59:47,440 --> 00:59:50,720 And so there are many hubs which have 1643 00:59:49,199 --> 00:59:53,519 these
pre-trained models, but perhaps 1644 00:59:50,719 --> 00:59:56,558 the biggest one is the Hugging Face Hub. 1645 00:59:53,519 --> 00:59:58,159 And I checked last night: it has 525,000 1646 00:59:56,559 --> 01:00:00,640 models 1647 00:59:58,159 --> 01:00:02,239 available. I think, if I recall, last year 1648 01:00:00,639 --> 01:00:04,719 when I taught Hodel, I think the number 1649 01:00:02,239 --> 01:00:07,118 was a lot smaller, maybe 50,000. So it's 1650 01:00:04,719 --> 01:00:09,039 growing really, really fast. Um, 1651 01:00:07,119 --> 01:00:12,599 and so, all right, let's just switch to a 1652 01:00:09,039 --> 01:00:12,599 Hugging Face Colab. 1653 01:00:15,199 --> 01:00:21,759 So, Hugging Face. How many of you are 1654 01:00:18,159 --> 01:00:24,719 familiar with Hugging Face? 1655 01:00:21,760 --> 01:00:26,720 Okay, it's good. All right, so, um, for 1656 01:00:24,719 --> 01:00:28,480 the others: basically you have a whole 1657 01:00:26,719 --> 01:00:30,318 bunch of pre-trained models on Hugging 1658 01:00:28,480 --> 01:00:32,240 Face. You actually have a lot of data 1659 01:00:30,318 --> 01:00:34,960 sets you can work with for your own 1660 01:00:32,239 --> 01:00:37,039 tasks. Uh, there are lots of people 1661 01:00:34,960 --> 01:00:39,039 demoing what they have built in this 1662 01:00:37,039 --> 01:00:40,558 thing called Spaces, and of course a lot 1663 01:00:39,039 --> 01:00:42,318 of documentation and so on. So the thing 1664 01:00:40,559 --> 01:00:44,000 is, what they have done is 1665 01:00:42,318 --> 01:00:46,318 they have organized all these models by 1666 01:00:44,000 --> 01:00:47,760 the kind of task you can use them for. 1667 01:00:46,318 --> 01:00:49,279 So you can see here there are a whole 1668 01:00:47,760 --> 01:00:50,960 bunch of computer vision tasks that you 1669 01:00:49,280 --> 01:00:52,480 can use them for.
There's a whole bunch 1670 01:00:50,960 --> 01:00:54,000 of natural language tasks like text 1671 01:00:52,480 --> 01:00:56,798 classification, 1672 01:00:54,000 --> 01:00:59,280 uh, feature extraction, this and that, lots 1673 01:00:56,798 --> 01:01:00,559 of interesting examples here. And so 1674 01:00:59,280 --> 01:01:01,760 what you do is you just literally can go 1675 01:01:00,559 --> 01:01:03,839 in there and say, okay, I want to do 1676 01:01:01,760 --> 01:01:05,200 text classification. You hit it, and then 1677 01:01:03,838 --> 01:01:06,798 it tells you all the models that are 1678 01:01:05,199 --> 01:01:08,558 available. There are like 50,000 models just 1679 01:01:06,798 --> 01:01:10,159 for text classification. And you can 1680 01:01:08,559 --> 01:01:11,680 look at, okay, which is, you know, most 1681 01:01:10,159 --> 01:01:13,118 downloaded, or which is the most liked, 1682 01:01:11,679 --> 01:01:14,318 and then you can just use them as a 1683 01:01:13,119 --> 01:01:17,358 starting point for whatever you want to 1684 01:01:14,318 --> 01:01:20,880 do. Okay. So that is Hugging Face, 1685 01:01:17,358 --> 01:01:24,960 and the way you use Hugging Face is, 1686 01:01:20,880 --> 01:01:26,798 I'm just connecting it. Um, 1687 01:01:24,960 --> 01:01:28,159 if you have a problem where the input is 1688 01:01:26,798 --> 01:01:29,440 natural language text, the first question 1689 01:01:28,159 --> 01:01:31,199 you have to ask yourself: is it standard 1690 01:01:29,440 --> 01:01:32,960 or not? Is it a standard task or not? If 1691 01:01:31,199 --> 01:01:34,639 it's a standard task, you just go that 1692 01:01:32,960 --> 01:01:37,199 route, do not reinvent the wheel. This thing 1693 01:01:34,639 --> 01:01:39,679 will usually work pretty well. Okay.
So 1694 01:01:37,199 --> 01:01:41,598 here we will use this thing called, um, 1695 01:01:39,679 --> 01:01:43,759 the transformers library from Hugging 1696 01:01:41,599 --> 01:01:45,599 Face, in particular the pipeline function, 1697 01:01:43,760 --> 01:01:47,520 to demonstrate quickly how to do this 1698 01:01:45,599 --> 01:01:48,960 thing. Fortunately this library, as of 1699 01:01:47,519 --> 01:01:50,000 this year, is pre-installed in Colab, so 1700 01:01:48,960 --> 01:01:51,599 we don't have to install it. We 1701 01:01:50,000 --> 01:01:53,920 can just start using it right away. So 1702 01:01:51,599 --> 01:01:57,119 we'll take this example where you have a 1703 01:01:53,920 --> 01:01:59,039 bunch of text which says, um: 1704 01:01:57,119 --> 01:02:00,480 "Dear Amazon, last week I got an Optimus 1705 01:01:59,039 --> 01:02:01,519 Prime action figure from your store in 1706 01:02:00,480 --> 01:02:04,000 Germany. Unfortunately, when I opened the 1707 01:02:01,519 --> 01:02:05,039 package, I discovered to my horror that I 1708 01:02:04,000 --> 01:02:06,719 had been sent an action figure of 1709 01:02:05,039 --> 01:02:08,639 Megatron instead." Can you imagine that 1710 01:02:06,719 --> 01:02:10,879 person's sheer distress at this? 1711 01:02:08,639 --> 01:02:12,159 Um, "as a lifelong enemy of the 1712 01:02:10,880 --> 01:02:14,640 Decepticons, I hope you can understand 1713 01:02:12,159 --> 01:02:17,039 my dilemma. So to resolve the issue, I 1714 01:02:14,639 --> 01:02:19,440 demand an exchange. Enclosed are copies; 1715 01:02:17,039 --> 01:02:21,039 expect to hear from you soon. Sincerely, 1716 01:02:19,440 --> 01:02:22,720 Bumblebee." 1717 01:02:21,039 --> 01:02:24,880 Okay. They should have come 1718 01:02:22,719 --> 01:02:26,558 up with a better name for this example. 1719 01:02:24,880 --> 01:02:29,358 Uh, all right, cool. So that's the text 1720 01:02:26,559 --> 01:02:31,040 we have.
So we import this pipeline 1721 01:02:29,358 --> 01:02:33,119 function, which is the one that basically gives 1722 01:02:31,039 --> 01:02:34,558 you the ability to start 1723 01:02:33,119 --> 01:02:36,720 using a model out of the box, without any training, 1724 01:02:34,559 --> 01:02:40,160 nothing like that. Okay, so we download 1725 01:02:36,719 --> 01:02:42,399 this thing. Um, oh wow, I got an A100 1726 01:02:40,159 --> 01:02:44,480 today. That happens very rarely. All 1727 01:02:42,400 --> 01:02:46,079 right, sorry. 1728 01:02:44,480 --> 01:02:48,000 So here, let's say you want to classify 1729 01:02:46,079 --> 01:02:50,000 that text. Okay, you just want to 1730 01:02:48,000 --> 01:02:52,880 classify it for sentiment. You literally 1731 01:02:50,000 --> 01:02:55,358 go in there and say pipeline 1732 01:02:52,880 --> 01:02:57,599 text-classification. That's the task you 1733 01:02:55,358 --> 01:02:59,519 want the pipeline to do for you, right? 1734 01:02:57,599 --> 01:03:01,280 And you create a classifier. Okay, it's 1735 01:02:59,519 --> 01:03:04,318 going to download a bunch of stuff, uh, 1736 01:03:01,280 --> 01:03:06,079 and then so on and so forth.
The first time, it just takes time to 1737 01:03:04,318 --> 01:03:08,558 download, and then 1738 01:03:06,079 --> 01:03:10,240 you literally take the 1739 01:03:08,559 --> 01:03:11,599 text you have here and then run it 1740 01:03:10,239 --> 01:03:14,078 through the classifier as if it were just a 1741 01:03:11,599 --> 01:03:17,280 little function, right? You get some 1742 01:03:14,079 --> 01:03:19,599 outputs, and then we just display them 1743 01:03:17,280 --> 01:03:21,519 this way: 1744 01:03:19,599 --> 01:03:23,760 the sentiment is negative with 90% 1745 01:03:21,519 --> 01:03:25,838 probability. Pretty good, right? Sequence 1746 01:03:23,760 --> 01:03:27,440 classification solved. I mean, 1747 01:03:25,838 --> 01:03:30,239 sentiment classification solved. So we'll 1748 01:03:27,440 --> 01:03:31,838 try a few different examples. "I hated 1749 01:03:30,239 --> 01:03:33,038 the movie." "If I said I loved the movie, 1750 01:03:31,838 --> 01:03:34,880 I would be lying." Okay, that's a little 1751 01:03:33,039 --> 01:03:36,400 tricky. "The movie left me speechless." 1752 01:03:34,880 --> 01:03:38,798 "Incredible." And then I had to add this 1753 01:03:36,400 --> 01:03:40,400 last thing here last night: "Almost but 1754 01:03:38,798 --> 01:03:42,000 not quite entirely unlike anything good 1755 01:03:40,400 --> 01:03:43,119 I've seen." Okay. And that's not 1756 01:03:42,000 --> 01:03:44,960 original. By the way, people who have 1757 01:03:43,119 --> 01:03:46,720 read Douglas Adams will know this famous 1758 01:03:44,960 --> 01:03:48,240 sentence about somebody drinking some 1759 01:03:46,719 --> 01:03:50,959 beverage and saying it's almost but not 1760 01:03:48,239 --> 01:03:52,558 quite entirely unlike tea. So I was 1761 01:03:50,960 --> 01:03:56,159 inspired by that. So anyway, we'll see 1762 01:03:52,559 --> 01:03:59,519 what happens. Um, 1763 01:03:56,159 --> 01:04:01,679 all right. Put it in there. Okay. So: 1764 01:03:59,519 --> 01:04:02,960 negative. "I hated the movie."
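[Editor's note: the whole flow just described fits in a few lines. This is a sketch rather than the exact classroom Colab: the pipeline's default checkpoint (a DistilBERT fine-tuned on SST-2 at the time of writing) and its exact scores may differ in your environment.]

```python
from transformers import pipeline

# Build a sentiment classifier from the pipeline's default pre-trained
# checkpoint; the first call downloads the model.
classifier = pipeline("text-classification")

reviews = [
    "I hated the movie.",
    "If I said I loved the movie, I would be lying.",
    "The movie left me speechless.",
    "Incredible.",
]
results = [classifier(review)[0] for review in reviews]
for review, res in zip(reviews, results):
    # each result looks like {'label': 'NEGATIVE', 'score': 0.99...}
    print(review, "->", res["label"], round(res["score"], 3))
```

Note the classifier always commits to one label with some probability; as discussed above for "The movie left me speechless," a well-calibrated model would put a borderline sentence near the 50% mark.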
Okay, fine. 1765 01:04:01,679 --> 01:04:05,038 "If I said I loved the movie, I'd be lying": 1766 01:04:02,960 --> 01:04:07,440 negative. "The movie left me speechless": 1767 01:04:05,039 --> 01:04:09,119 uh, it says it's negative, but it could go 1768 01:04:07,440 --> 01:04:09,838 either way, right? A good classifier 1769 01:04:09,119 --> 01:04:11,599 would have probably given you a 1770 01:04:09,838 --> 01:04:13,759 probability around the 50% mark, because 1771 01:04:11,599 --> 01:04:15,760 it's sort of right on the fence. Um, 1772 01:04:13,760 --> 01:04:17,680 "incredible" is positive, and then it 1773 01:04:15,760 --> 01:04:20,640 got fooled by my crazy long sentence and 1774 01:04:17,679 --> 01:04:22,159 it says it's positive. Okay, now that's 1775 01:04:20,639 --> 01:04:23,679 classification. Here's one other quick 1776 01:04:22,159 --> 01:04:25,759 example. So, you can actually give it a 1777 01:04:23,679 --> 01:04:28,318 piece of text, right? For example, you 1778 01:04:25,760 --> 01:04:30,319 can take a Reuters news story. 1779 01:04:28,318 --> 01:04:32,880 You can feed it in and say: extract all the 1780 01:04:30,318 --> 01:04:34,159 company names from it. Extract company 1781 01:04:32,880 --> 01:04:35,599 names, people's names, and things like 1782 01:04:34,159 --> 01:04:37,920 that. It's called named entity 1783 01:04:35,599 --> 01:04:40,240 extraction. And back in 1784 01:04:37,920 --> 01:04:42,400 the day, people would 1785 01:04:40,239 --> 01:04:44,479 painstakingly hand-build all these 1786 01:04:42,400 --> 01:04:46,079 very complex systems to do named 1787 01:04:44,480 --> 01:04:48,400 entity extraction. Now it's just a 1788 01:04:46,079 --> 01:04:50,559 pipeline away.
So you can take this 1789 01:04:48,400 --> 01:04:53,280 thing and you can say: create a pipeline 1790 01:04:50,559 --> 01:04:54,798 for named entity extraction. And for any 1791 01:04:53,280 --> 01:04:56,240 particular task that you're using, there 1792 01:04:54,798 --> 01:04:57,838 might be a few additional parameters you 1793 01:04:56,239 --> 01:05:00,000 can set, right, as a part of the 1794 01:04:57,838 --> 01:05:03,000 configuration. So we download this 1795 01:05:00,000 --> 01:05:03,000 pipeline. 1796 01:05:08,480 --> 01:05:14,798 Okay, perfect. And then we run it and look at the 1797 01:05:11,199 --> 01:05:16,960 output. So it says, okay, good: Amazon is 1798 01:05:14,798 --> 01:05:18,559 an organization, 1799 01:05:16,960 --> 01:05:21,119 uh, 1800 01:05:18,559 --> 01:05:22,400 and Germany is a location, LOC, which is 1801 01:05:21,119 --> 01:05:23,920 nice. So these things have a standard 1802 01:05:22,400 --> 01:05:24,798 vocabulary, ORG or LOC, things like 1803 01:05:23,920 --> 01:05:26,960 that, which you can read up on in the 1804 01:05:24,798 --> 01:05:29,599 documentation. Uh, and then Bumblebee is 1805 01:05:26,960 --> 01:05:32,079 a person. And then, boy, all the 1806 01:05:29,599 --> 01:05:33,760 Optimus Prime transformer stuff, on that 1807 01:05:32,079 --> 01:05:36,480 it got fooled, right? It thinks Optimus 1808 01:05:33,760 --> 01:05:38,000 Prime is miscellaneous, Decepticons is 1809 01:05:36,480 --> 01:05:39,039 miscellaneous, and so on and so forth. 1810 01:05:38,000 --> 01:05:41,039 But you get the idea. You can take 1811 01:05:39,039 --> 01:05:42,400 standard things like Reuters news stories, 1812 01:05:41,039 --> 01:05:44,160 and boom, you can get 1813 01:05:42,400 --> 01:05:45,440 very good entity extraction right off 1814 01:05:44,159 --> 01:05:47,038 the bat.
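[Editor's note: the entity-extraction call can be sketched like this. `aggregation_strategy="simple"` is an example of those extra configuration parameters; it merges word pieces back into whole entities. The default checkpoint, a BERT fine-tuned on CoNLL-2003, is what produces the standard PER / ORG / LOC / MISC tags; the shortened letter text is illustrative.]

```python
from transformers import pipeline

# Named entity recognition; aggregation_strategy="simple" groups sub-word
# pieces into whole entities with a single tag each.
ner = pipeline("ner", aggregation_strategy="simple")

text = ("Dear Amazon, last week I got an Optimus Prime action figure "
        "from your store in Germany. Sincerely, Bumblebee.")
entities = ner(text)
for ent in entities:
    # each entity carries a tag, the matched text span, and a confidence
    print(ent["entity_group"], ent["word"], round(float(ent["score"]), 3))
```

With the default model this typically tags Amazon as ORG, Germany as LOC, and Bumblebee as PER, with the Transformers-franchise names landing in MISC, matching what the lecture describes on screen.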
And once you get these 1815 01:05:45,440 --> 01:05:48,960 entities extracted, then you can put 1816 01:05:47,039 --> 01:05:50,640 them into a nice structured data table, 1817 01:05:48,960 --> 01:05:53,280 like a database, and then you can run 1818 01:05:50,639 --> 01:05:55,679 traditional machine learning on it. 1819 01:05:53,280 --> 01:05:58,559 Okay. Um, and then I had, I think, a few 1820 01:05:55,679 --> 01:06:01,598 more examples, of question answering, and, 1821 01:05:58,559 --> 01:06:02,798 uh, actually let's just try that. Um, you 1822 01:06:01,599 --> 01:06:03,920 can actually give it a passage and ask a 1823 01:06:02,798 --> 01:06:07,599 question about it, and it can actually 1824 01:06:03,920 --> 01:06:09,119 give you the answer, which gets into the 1825 01:06:07,599 --> 01:06:10,960 causal transformer thing that we're 1826 01:06:09,119 --> 01:06:12,720 going to see on Monday, which builds up 1827 01:06:10,960 --> 01:06:14,480 into large language models, because you 1828 01:06:12,719 --> 01:06:16,000 obviously can give 1829 01:06:14,480 --> 01:06:17,440 a passage to ChatGPT and ask a 1830 01:06:16,000 --> 01:06:19,599 question, ask it to give you an answer, so 1831 01:06:17,440 --> 01:06:20,880 it's really in that vein. But, um, just 1832 01:06:19,599 --> 01:06:25,280 for fun, let's just do that to see if 1833 01:06:20,880 --> 01:06:27,440 it's any good. Um, okay: so, what does the 1834 01:06:25,280 --> 01:06:29,359 customer want? And the output is "an 1835 01:06:27,440 --> 01:06:32,480 exchange of Megatron," and it's telling 1836 01:06:29,358 --> 01:06:34,558 you where the relevant passage starts in the text 1837 01:06:32,480 --> 01:06:37,599 and where it ends. 1838 01:06:34,559 --> 01:06:39,119 It's pretty good, right?
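[Editor's note: the question-answering version, as a sketch. The context string here is a shortened paraphrase of the complaint letter, not the original Colab text; the useful part is that extractive QA returns character offsets, which is exactly what makes the check-where-it-found-the-answer trick possible.]

```python
from transformers import pipeline

# Extractive question answering from the default pre-trained checkpoint.
qa = pipeline("question-answering")

context = ("Dear Amazon, last week I got an Optimus Prime action figure from your "
           "store in Germany. Unfortunately, I was sent Megatron instead. "
           "To resolve the issue, I demand an exchange of Megatron for the "
           "Optimus Prime figure I ordered.")
answer = qa(question="What does the customer want?", context=context)
print(answer["answer"], answer["start"], answer["end"])

# The start/end character offsets point back into the passage, so the
# claimed answer span can be checked against the source text:
print(context[answer["start"]:answer["end"]])
```

The two printed strings are the same span, which is the point: the model doesn't just answer, it cites the exact slice of input it used.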
Because, remember, if you have stuff like this, then when you ask a large language model a question and it gives you an answer, you can actually ask it to tell you exactly where in the input it found the answer. And because you know these things are going to hallucinate, you can look at the input it's claiming to use, look at what it says, and see whether they actually match. It's a way to essentially do QA on LLM output. Okay, so that's what we have here, and I have other material, much of which I'll ignore for the moment, because I want to go back to the PowerPoint. So, yeah: if you have a standard task, you can just use pipelines in Hugging Face to solve many of them out of the box, without any heavy lifting. So, I mentioned earlier on that transformers have proven to be effective for a whole bunch of domains outside of natural language processing, like speech recognition, computer vision, and so on and so forth.
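That grounding check can be sketched in a few lines. Everything here is illustrative: the context passage is invented to resemble the lecture's customer-complaint example, and the `result` dict is hard-coded in the shape a question-answering pipeline produces (answer text plus start/end character offsets), rather than coming from a real model call.

```python
# Illustrative context and question (not the lecture's exact passage).
context = ("Dear Amazon, last week I ordered an Optimus Prime action figure, "
           "but you sent me a Megatron figure instead. "
           "I demand an exchange of Megatron.")
question = "What does the customer want?"

# Hard-coded result in the shape a QA pipeline returns: answer + offsets.
start = context.index("an exchange")
result = {"answer": "an exchange of Megatron",
          "start": start,
          "end": start + len("an exchange of Megatron")}

# The grounding check: does the claimed span in the input actually
# contain the answer the model gave?
grounded = context[result["start"]:result["end"]] == result["answer"]
print("answer is grounded in the input:", grounded)
```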
And so I want to give you a couple of quick examples of how to think about using transformers for non-text applications. Okay. So the key insight here is that the architecture of the transformer block we have looked at, amazingly enough, can be used as-is, with no changes, no surgery needed, no clever thinking required, for any particular application. What is needed, where the clever thinking may be required, is that you need to take the inputs you're working with and figure out a way to tokenize and encode them into embeddings, which can then be sent into the transformer. So all the action is in taking that non-text input and figuring out a way to cast it into the language of embeddings. That's the game. Okay. So here is something called the vision transformer, which is very famous, actually. I think it may be the first, perhaps the first, transformer architecture that was applied to vision problems.
So let's say you have a picture. It's just a picture. You have to find a way to create embeddings from this picture, to tokenize this picture in some way. With sentences it's pretty trivial to figure out how to tokenize: a three-word sentence is obviously three tokens. But with a picture, what do you do? It's kind of weird to think of tokenizing a picture. So what these people did is they said: you know what, I'm going to take this picture and chop it up into small squares. Right? So in this example, they have taken this big picture and chopped it up into nine little pictures. Okay? Then you can take each of those nine pictures. Each of those nine pictures, if you look at how it's represented, is just three tables of numbers, right? The RGB values. So you can take all those numbers and just create one giant long vector from them. Okay?
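The chopping-and-flattening step can be sketched in NumPy. The sizes here are illustrative (a 48×48 image in a 3×3 grid of 16×16 patches), chosen just to get nine patches like the lecture's figure; they are not the ViT paper's actual dimensions.

```python
import numpy as np

# A random stand-in image: height x width x RGB, "three tables of numbers".
image = np.random.rand(48, 48, 3)
patch = 16

# Chop the picture into a 3x3 grid of 16x16 sub-images.
patches = [image[r:r + patch, c:c + patch]
           for r in range(0, 48, patch)
           for c in range(0, 48, patch)]

# Flatten each patch's numbers into one long vector: 16*16*3 = 768 numbers.
flat = np.stack([p.reshape(-1) for p in patches])
print(flat.shape)  # (9, 768): nine patches, one long vector each
```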
You have a huge long vector, and then you run it through a dense layer to come up with a smaller vector, and that smaller vector is your embedding. That's it. The way you transform the long vector into the small vector is just a dense layer whose weights can be learned. So what these people did is they said: I'm going to first chop it up into these patches, and then I take each patch and do a linear projection. Right? A flattened patch is nothing more than the three tables of numbers flattened into a long vector; that's what the word "flatten" here means. And once you flatten it, I'm just going to run it through a dense layer. By the way, you will see the words "linear projection": it's a synonym for "run it through a dense layer". So you run it through a dense layer, and you get these nice vectors. And now you say: well, you know what?
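The linear projection is one matrix multiply. In this sketch the weights are random stand-ins for what would actually be learned during training, and the dimensions (768 in, 64 out) are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
flat_patches = rng.standard_normal((9, 768))  # nine flattened patches

# The "dense layer": a learnable weight matrix and bias. Random here;
# in a real model these weights are learned.
W = rng.standard_normal((768, 64)) * 0.02
b = np.zeros(64)

# Linear projection = run it through a dense layer.
embeddings = flat_patches @ W + b
print(embeddings.shape)  # (9, 64): one small embedding per patch
```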
I have to take the order of these things into account, because clearly this little patch is in the top left while this patch is somewhere in the middle. Right? The order matters in the picture; otherwise every jumbled version would be the same thing. So you use positional embeddings. You basically say: there are nine positions in any picture, 0, 1, 2, 3, 4, 5, 6, 7, 8, nine positions. So I'm going to create nine position embeddings, and then I'm just going to add them to the patch embeddings. Just like we did with words: with words, each word had an embedding, each position had an embedding, and we added them up. Here each image patch has an embedding, the position of the little patch in the picture has an embedding, and we add them up. Okay? And then, because we want to use it for classification, no problem: we'll have a little CLS token, and then we just run it through the transformer. That's it.
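Putting those two steps together, here is a minimal sketch. Again, every array is a random stand-in for learned parameters, and the width 64 is illustrative.

```python
import numpy as np

rng = np.random.default_rng(1)
patch_emb = rng.standard_normal((9, 64))  # one embedding per patch
pos_emb = rng.standard_normal((9, 64))    # one embedding per position 0..8

# Element-by-element addition, exactly as with word + position embeddings.
total = patch_emb + pos_emb

# Prepend a CLS token so there is something to attach a classifier to.
cls = rng.standard_normal((1, 64))
sequence = np.concatenate([cls, total])
print(sequence.shape)  # (10, 64): CLS plus nine patch tokens, ready for the transformer
```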
And then you get the CLS token out, and you can attach a softmax to it and say: okay, it's a bird, it's a ball, it's a car. That's it. This simple approach actually works, amazingly enough. Okay, so that is the vision transformer, and I'm going through it fast just to give you a sense for how these things work. Any questions? Yeah. "My question is: in the case of text we had a fixed number of tokens, the number of words that could be in the English vocabulary. But here, if you look at images, they would probably go into the trillions. We're not talking about one image; we take a whole lot of images and chop up each one, and each one would have its own weights, its own parameters." There is no notion of vocabulary here. All we're saying is that given any image, we create nine patches, sub-images, from it. Each of those patches gets passed through a dense layer and out comes an embedding.
So at that point, any image you give me, I'm going to get nine embeddings out of it. And once I get the nine embeddings, I just throw them into the meat grinder, the transformer meat grinder. All right. So, another example. I think some of you have asked me outside of class: how good are transformers for structured data, for tabular data? For tabular data in general, things like XGBoost and gradient boosting work really, really well, so it's good to try them. I certainly don't think transformers and deep learning networks have any great edge over XGBoost for structured data problems, so it's worth trying both. However, you can use transformers for this stuff too. This is called the TabTransformer, one of the first transformers to come out for tabular data, and again, it's pretty simple. All you do is this. In any kind of input that you have, you will have some categorical variables, right? Like blood pressure, things like that. Actually, not blood pressure, bad example: gender, right?
And so on and so forth. So what you do is you take all the categorical features, and for each categorical feature, you create embeddings, because a categorical feature is just text. A categorical feature is just text, so you can create text embeddings for it. No problem. And you take all the continuous features, right? Cholesterol and blood pressure and whatnot, to go back to the heart disease example, and you just collect them all and create a vector out of them. Just a vector. Okay? Then you run the embeddings for all the categorical variables through a nice transformer block. And you can see here it's exactly the block we have seen before, no difference. And then at the very end, when it comes out of the transformer, you take all the contextual stuff coming out of the transformer and concatenate it with the continuous features. Okay.
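The whole TabTransformer flow might be sketched like this. Everything is a stand-in: the feature names, the vocabulary size, the embedding width, and especially the "transformer block", which is replaced here by a simple placeholder that just returns same-shaped contextualized embeddings, since the point is the data flow, not the attention math.

```python
import numpy as np

rng = np.random.default_rng(2)

# Two categorical features (say, gender and smoker) encoded as vocabulary
# indices, looked up in a learned embedding table (random stand-in here).
cat_values = np.array([1, 0])
embedding_table = rng.standard_normal((10, 32))
cat_emb = embedding_table[cat_values]          # (2, 32)

# Placeholder for the transformer block: contextual embeddings, same shape.
contextual = cat_emb + 0.1 * rng.standard_normal(cat_emb.shape)

# Continuous features (cholesterol, blood pressure, ...) stay a raw vector.
continuous = np.array([193.0, 121.0])

# Concatenate contextual categorical embeddings with the continuous vector;
# a dense head would then map this to the output.
features = np.concatenate([contextual.reshape(-1), continuous])
print(features.shape)  # (66,): 2 embeddings of 32 numbers, plus 2 continuous values
```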
And then you run it through maybe one or more dense layers, and boom: output. So this is a tabular data transformer. And there are many refinements and improvements that have come in the years since then. But the key thing I want you to remember from here is that categorical variables can be very easily represented as embeddings. That's the key. Okay. All right, so that's that. Now, once the input has been transformed into this common language of embeddings, we can process it without changing the architecture of the block itself, because all it wants is embeddings. It's like: you give me embeddings, I give you great contextual embeddings out, and nobody gets hurt, right? That is the deal with the transformer stack. So now, since the transformer is agnostic to the kind of input, as long as it comes in in the form of embeddings, you can use it for multimodal data very easily.
So, for example, let's say that you have a problem in which you have a picture that has to be sent in, some text that goes in, and a bunch of tabular data coming in. Well, you take the text and do language embeddings, like we know how to do. You take the image and do image embeddings, like we just saw with the vision transformer. You take the tabular data and do tabular data embeddings, like we saw with the TabTransformer. Once we do that, it's all a bunch of embeddings. And then you attach a little class token on top, send it through a bunch of transformer blocks, and out comes a contextual class token, the contextual version. Run it through maybe a sigmoid or a softmax, predict the label: done. So this is extremely powerful, this ability to handle multimodal data. Okay.
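The multimodal recipe just described can be sketched end to end. All arrays here are random stand-ins, the shared width of 64 is illustrative, and the transformer stack itself is replaced by a trivial pooling placeholder (in a real model, attention layers would mix the tokens before the CLS is read out).

```python
import numpy as np

rng = np.random.default_rng(3)
d = 64  # one shared embedding width for every modality

text_emb = rng.standard_normal((5, d))   # five word embeddings
image_emb = rng.standard_normal((9, d))  # nine patch embeddings (vision transformer)
tab_emb = rng.standard_normal((3, d))    # three categorical-feature embeddings
cls = rng.standard_normal((1, d))        # the class token attached on top

# Once everything is embeddings, it is all one sequence.
sequence = np.concatenate([cls, text_emb, image_emb, tab_emb])  # (18, d)

# Placeholder for the transformer stack: here just mean pooling.
contextual_cls = sequence.mean(axis=0)

# Softmax head on the contextual class token predicts the label.
logits = contextual_cls @ rng.standard_normal((d, 3))
probs = np.exp(logits) / np.exp(logits).sum()
print(probs)  # probabilities over three classes
```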
And that's why, for example, if you look at Google Gemini 1.5 Pro, GPT-4 Vision, and so on, you can send them images and a question and you'll get an answer back, because every modality that goes in is cast into embeddings, and once it's embedded, "embeddingized", the transformer doesn't care. It'll just do its thing. It will decide, for example, that this word in your question is actually highly related to that patch in the picture. Right? It'll just figure that out. Okay. That's all I had, because the time is approaching 9:55. Perfect. All right, folks, thanks. Have a great rest of your week.