So, let's get started. I'll be talking about building LLMs today. I think a lot of you have heard of LLMs before, but just as a quick recap: LLMs, standing for large language models, are basically all the chatbots you've been hearing about recently. So, ChatGPT from OpenAI, Claude from Anthropic, Gemini, Llama, and other models like these. Today we'll be talking about how they actually work. It's going to be an overview, because it's only one lecture and it's hard to compress everything, but hopefully I'll touch a little bit on all the components that are needed to train some of these LLMs. Also, if you have a question, please interrupt me and ask. Most likely other people in the room or on Zoom have the same question, so please ask.

Great. So, what matters when training LLMs? There are a few key components. One is the architecture: as you probably all know, LLMs are neural networks, and when you think about neural networks, you have to think about which architecture you're using. Another really important component is the training loss and the training algorithm, meaning how you actually train these models. Then there's data: what you train these models on. Then evaluation: how you know whether you're actually making progress toward the goal of LLMs. And then the systems component: how you actually make these models run on modern hardware, which is really important because these models are really large. So now more than ever, systems are a really important topic for LLMs.

So those are the five components. You probably all know that LLMs, and if you don't, now you do, are all based on transformers, or at least some version of transformers. I'm actually not going to talk about the architecture today: one, because I gave a lecture on transformers a few weeks ago, and two, because you can find so much information about transformers online.
There's much less information about the other four topics, so I really want to talk about those. Another thing to say is that most of academia actually focuses on architectures, training algorithms, and losses. As academics, and I've done that for a big part of my career, we like thinking that making new architectures and new models is what's important. But in reality, honestly, what matters in practice is mostly the three other topics: data, evaluation, and systems, which is what most of industry actually focuses on. So that's also one of the reasons why I don't want to talk too much about the architecture: the rest is really super important.

Great. So, an overview of the lecture. I'll be talking about pretraining. You've probably heard that word; this is the classical language-modeling paradigm, where you basically train your language model to model all of the internet. And then there's post-training, which is a more recent paradigm: taking these large language models and making them essentially AI assistants. This is more of a recent trend, since ChatGPT. So if you've ever heard of GPT-2 or GPT-3, that's really pretraining land. If you've heard of ChatGPT, which you probably have, that's really post-training land. I'll be talking about both, but I'll start with pretraining, and specifically I'll talk about what the task of pretraining LLMs is and what loss people actually use.

So, language modeling. This is a quick recap. Language models, at a high level, are simply models of a probability distribution over sequences of tokens or words. So it's basically some model of p(x1, ..., xL), where x1 is the first word and xL is the last one in the sequence or sentence.
Very concretely, if you have a sentence like "The mouse ate the cheese," what the language model gives you is simply the probability of this sentence being uttered by a human or being found online. If you have another sentence like "The the mouse ate cheese," there are grammatical mistakes here, so the model, which should have some syntactic knowledge, should know that this has a lower likelihood of appearing online. If you have another sentence like "The cheese ate the mouse," then the model should hopefully know that cheese doesn't usually eat mice. So there's some semantic knowledge, and this sentence is less likely than the first one. That's basically, at a high level, what language models are.

One phrase that you've probably been hearing a lot in the news is generative models. This just means models that can generate sentences, or generate data more generally. The reason we say language models are generative models is that once you have a model of a distribution, you can simply sample from it, and then you can generate data. So we can generate sentences using a language model.

The type of models that people are all currently using are what we call autoregressive language models. The key idea of autoregressive language models is that you take this distribution over sequences and decompose it into the distribution of the first word, multiplied by the distribution of the second word given the first word, multiplied by the distribution of the third word given the first two words, and so on. There's no approximation here; this is just the chain rule of probability, which hopefully you all know about. Really, no approximation, just one way of modeling a distribution. Slightly more concisely, you can write it as a product over positions of the probability of the next word given everything that happened in the past, that is, given the context. This is what we call an autoregressive language model.
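Written out, the decomposition is exactly the chain rule of probability (here x_{1:i-1} denotes the context x_1, ..., x_{i-1}; this is a standard identity, not specific to any particular model):

```latex
p(x_1, \dots, x_L)
  = p(x_1)\, p(x_2 \mid x_1)\, p(x_3 \mid x_1, x_2) \cdots p(x_L \mid x_1, \dots, x_{L-1})
  = \prod_{i=1}^{L} p(x_i \mid x_{1:i-1})
```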
Again, this is really not the only way of modeling a distribution; it's just one way, and it has some benefits and some downsides. One downside of autoregressive language models is that when you actually sample from them, you basically have a for loop: you generate the next word, condition on that word, and then generate another word. So if you want to generate a longer sentence, it takes more time to generate it. There are some downsides to this current paradigm, but that's what we currently have, so I'm going to talk about this one.

Great. So, autoregressive language models. At a high level, the task of an autoregressive language model is simply predicting the next word, as I just said. If we have a sentence like "She likely prefers," one potential next word might be "dogs." The way we do it is that we first tokenize: you take the words or subwords, tokenize them, and give an ID to each token. So here you have one, two, three. Then you pass it through a black box (as I already said, we're not going to talk about the architecture), and you get a probability distribution over the next word, or rather the next token. Then you sample from this distribution to get a new token, and you detokenize: you get a new ID, you detokenize it, and that's how you sample from a language model.

One thing that's important to note is that the last two steps are only needed during inference. During training, you just predict the distribution over the next token, compare it to the real token that actually occurred, and then change the weights of your model to increase the probability of generating that token.
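To make that inference loop concrete, here is a minimal sketch of autoregressive sampling. The `tokenizer` and `model` objects and their methods are hypothetical stand-ins for whatever implementation you use; the point is the shape of the loop, where each new token is fed back in as context:

```python
import torch

def sample(model, tokenizer, prompt: str, max_new_tokens: int = 20) -> str:
    # Tokenize the prompt into a sequence of integer token IDs.
    ids = tokenizer.encode(prompt)           # e.g. "She likely prefers" -> [1, 2, 3]
    for _ in range(max_new_tokens):
        # The model maps the context to a distribution over the next token.
        logits = model(torch.tensor([ids]))  # shape: (1, len(ids), vocab_size)
        probs = torch.softmax(logits[0, -1], dim=-1)
        # Sample one token ID from that distribution...
        next_id = torch.multinomial(probs, num_samples=1).item()
        # ...and condition on it for the next step. This is the for loop that
        # makes generation time grow with output length.
        ids.append(next_id)
    # Detokenize: map the token IDs back to text.
    return tokenizer.decode(ids)
```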
Great. So, autoregressive neural language models. To be slightly more specific, still without talking about the architecture, the first thing we do is... sorry, yes, a question on the previous slide.

Predicting the probability of the next token, does this mean that your final output vector has to have the same dimensionality as the number of tokens that you have?

Yes.

How do you deal with adding more tokens to your [INAUDIBLE]?

Yeah, so we're going to talk about tokenization later, so you'll get some sense of this. You basically can't deal with adding new tokens. I'm kind of exaggerating; there are methods for doing it, but essentially people don't do it. So it's really important to think about how you tokenize your text, and that's why we'll talk about it later. But it's a very good point to note: the vocabulary size, the number of tokens that you have, is essentially the output dimension of your language model. So it's actually pretty large.

So, autoregressive neural language models. The first thing you do is take every word, or every token, and embed it, so you get some vector representation for each of these tokens. You pass them through some neural network; as we said, it's a transformer. Then you get a representation for all the words in the context, basically a representation of the entire sentence. You pass it through a linear layer, as was just said, to map it so that the number of outputs is the number of tokens in the vocabulary. You then pass it through a softmax, and you get a probability distribution over the next token given every word in the context. And the loss that you use: it's essentially the task of classifying the next token, a very simple kind of machine learning task, so you use the cross-entropy loss.
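Written as a formula (a standard formulation, nothing model-specific), the training objective is the cross-entropy between the one-hot target and the predicted distribution, and minimizing it over the text is the same as maximizing its log-likelihood:

```latex
\mathcal{L}(\theta) = -\sum_{i=1}^{L} \log p_\theta(x_i \mid x_{1:i-1}),
\qquad
\min_\theta \mathcal{L}(\theta) \;\iff\; \max_\theta \log p_\theta(x_1, \dots, x_L)
```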
Concretely, you look at the actual target that occurred, the target distribution, which is a one-hot encoding. In this case it says: the real word that happened is "cat," so it's a one-hot distribution over "cat." And here (do you see my mouse? oh, yeah) is the distribution that you generated. You apply the cross-entropy, which really just increases the probability of generating "cat" and decreases the probability of generating all the other tokens. One thing to note is that, as you all know, this is just equivalent to maximizing the log-likelihood of the text: you can rewrite the maximum of the probability of this autoregressive language-modeling task as the minimum of the loss (I just added a log and a minus sign), which is the cross-entropy loss. So minimizing the loss is the same thing as maximizing the likelihood of your text. Any questions?

OK, tokenizers. This is one thing that people usually don't talk that much about, but tokenizers are extremely important, so it's really important that you understand at least what they do at a high level. Why do we need tokenizers in the first place? First, they're more general than words. One simple thing you might think of is to take every word you have and say every word is a token in its own right. But then what happens if there's a typo in a word? You might not have any token associated with that misspelled word, and then you don't know how to pass it into the large language model. So what do you do? Also, words are fine for Latin-based languages, but if you think about a language like Thai, you won't have a simple way of tokenizing by spaces, because there are no spaces between words.
So really, tokens are much more general than words. That's the first thing. The second thing you might think of is tokenizing every sentence character by character: A is one token, B is another token. That would actually work, and probably very well. The issue is that your sequences then become super long, and as you probably remember from the lecture on transformers, the complexity grows quadratically with the length of the sequence. So you really don't want super-long sequences.

Tokenizers basically try to deal with those two problems by giving common subsequences their own tokens. The way you should usually think about it: on average, every token is around three to four letters. There are many algorithms for tokenization; I'll just talk about one of them to give you a high-level idea, which is what we call Byte Pair Encoding (BPE). It's actually one of the two most common tokenizers.

The way you train a tokenizer is that you first start with a very large corpus of text. Here I'm really not talking about training the large language model yet; this is purely for the tokenization step. So this is my large corpus of text, with these five words. Then you assign every character in this corpus a different token. Here, I just split up every character into its own token and color-coded all of those tokens. Then you go through your text, and every time you see a pair of tokens that is very common, the most common pair of tokens, you merge them. So here you see the tokens "t" and "o" next to each other three times, so you just say that's a new token. Then you continue; you repeat that. So now you have "tok", which happens three times; "toke" with an E, which happens twice; "token", which happens twice; and then "ex", which also happens twice.
If you were to train a tokenizer on this corpus of text, which is very small, that's how you would end up with a trained tokenizer. In reality, you do it on a much larger corpus of text. And this is the real tokenizer of, actually, I think GPT-3 or ChatGPT. Here you see how it would actually separate these words, and you basically see the same thing as in the previous example: "token" becomes its own token, and "tokenizer" gets split into two tokens, "token" and "izer". So yeah, that's all about tokenizers. Any questions on that? Yeah?

How do you deal with spaces, and how do you deal with [INAUDIBLE]?

Yeah, so there's actually a step before tokenizers, which is what we call pre-tokenizers, and it deals with exactly what you just said. In theory, there's no reason to deal with spaces and punctuation separately: you could just say every space gets its own token, every punctuation mark gets its own token, and then you do all the merging. The problem is efficiency. Training these tokenizers takes a long time, because you have to consider every pair of tokens. So what you end up doing, and pre-tokenizers are very English-specific, is saying that if there's a space, you're not going to merge the token that came before with the token that came after. So you never merge across spaces. But this is just a computational optimization; you could theoretically deal with spaces the same way you deal with any other character. Yeah?

When you merge tokens, do you delete the tokens that you merged away, or do you keep the smaller tokens that you merged?

You actually keep the smaller tokens. In reality, it doesn't matter much, because on a large corpus of text you will actually see everything.
But you usually keep the small ones. The reason you want to do that is that, in case there are, as we said before, grammatical mistakes or typos, you still want to be able to represent those words character by character. So, yeah. Yes?

Are the tokens unique? In this case, for T-O-K-E-N, is there only one occurrence, or do you need to leave multiple occurrences so they could take on different meanings or something?

Oh, I see what you're saying. No, every token has its own unique ID. This is a great question. For example, if you think about "bank", which could be a bank for money or the bank of a river, it will have the same token. But the model, the transformer, will learn that based on the words around it, it should associate that token (I'm being very hand-wavy here) with a representation that is either more like the money sense or the water sense of "bank". But it's the transformer that does that, not the tokenizer. Yes?

You mentioned that during tokenization you keep the smaller tokens you started with, right? Like if you start with a T, you keep the T, and then you build your tokenizer out to [INAUDIBLE] allow input tokens. So let's say maybe you didn't train on "token", but in your data you are trying to encode "token". How does the tokenizer know to encode it with "token" or to [INAUDIBLE]?

Yeah, great question. When you tokenize, so after training of the tokenizer, when you actually apply it, you basically always choose the largest token that you can apply. So if you can use "token", you will never use "t"; you will always use "token".
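Here is a toy sketch of both halves of what was just described: training BPE merges by repeatedly fusing the most frequent adjacent pair, and then encoding new text greedily with the longest matching token. This is a simplification (real implementations work on bytes, handle pre-tokenization, and are heavily optimized), not any particular library's algorithm:

```python
from collections import Counter

def train_bpe(corpus: list[str], num_merges: int) -> list[str]:
    """Learn a vocabulary by repeatedly merging the most frequent adjacent pair."""
    words = [list(w) for w in corpus]                # start character by character
    vocab = sorted({ch for w in words for ch in w})  # base tokens are kept around
    for _ in range(num_merges):
        pairs = Counter()
        for w in words:
            for a, b in zip(w, w[1:]):
                pairs[(a, b)] += 1
        if not pairs:
            break
        a, b = pairs.most_common(1)[0][0]            # the most common adjacent pair
        vocab.append(a + b)
        # Replace every occurrence of the pair with the merged token.
        for w in words:
            i = 0
            while i < len(w) - 1:
                if w[i] == a and w[i + 1] == b:
                    w[i:i + 2] = [a + b]
                i += 1
    return vocab

def encode(text: str, vocab: list[str]) -> list[str]:
    """Greedily take the longest vocabulary token that matches at each position."""
    tokens, i = [], 0
    while i < len(text):
        match = max((t for t in vocab if text.startswith(t, i)), key=len)
        tokens.append(match)
        i += len(match)
    return tokens

vocab = train_bpe(["token", "tokens", "tokenizer", "text", "next"], num_merges=6)
print(encode("tokenizer", vocab))  # e.g. ['token', 'i', 'z', 'e', 'r'], depending on the merges
```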
But actually, people don't usually talk that much about tokenizers, and there are a lot of computational benefits, computational tricks, that you can use to make these things faster. And honestly, I think a lot of people think we should just get away from tokenizers and tokenize character by character or byte by byte. As I said, right now there's this issue of sequence length, but maybe one day, in five or ten years, we will have different architectures that don't scale quadratically with the length of the sequence, and maybe we'll move away from tokenizers.

Can you share with us the drawbacks? Why do people want to move away from tokenizers?

Yeah. I think one good example is math. If you think about math, numbers right now are not tokenized digit by digit. For example, 327 might have its own token, which means that when models see numbers, they don't see them the same way we do. And this is very annoying, because the reason we can generalize in math is that we can deal with every digit separately and then do composition: you know that adding numbers is the same as adding each digit separately, plus whatever carries over. Models can't do that, so you have to do special tokenization. And one of the big changes that GPT-4 made is changing the way they tokenize code. For example, in Python code you often have those four spaces at the beginning of a line. Those were dealt with strangely before, and as a result, the model couldn't really understand how to deal with code. So tokenizers actually matter a lot. OK, I'll move on right now, but we can come back to tokenizers later.

Great. So we've talked about the task, the loss, and the tokenizer; let's talk a little bit about evaluation. The way that LLMs are usually evaluated is using what we call perplexity.
At a high level, perplexity is basically just your validation loss. The slight difference is that we use something slightly more interpretable: we take the average per-token loss and then exponentiate it. The reason you exponentiate is that the loss has a log inside, and one, humans are actually pretty bad at thinking in log space, and two, logs depend on the base of the logarithm, while once you exponentiate, everything is in units of vocabulary size. And averaging per token just makes your perplexity independent of the length of the sequence. So perplexity is two to the power of the average per-token loss of the sequence.

Perplexity is between 1 and the vocabulary size of your tokenizer. It's 1 if you predict every word perfectly: then every word has probability one, the product is one, and the best perplexity you can have is 1. If you really have no idea, you predict each token with probability one over the vocabulary size, and if you do the simple math, you get a perplexity equal to the vocabulary size. So the intuition for perplexity is that it's the number of tokens that your model is, kind of, hesitating between. If your model is perfect, it doesn't hesitate; it knows exactly the word. If it really has no idea, it hesitates between the entire vocabulary.
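As a formula (a standard definition; base 2 here matches the "two to the power of the average loss" phrasing, though natural log with exp is equally common):

```latex
\mathrm{PPL}(x_{1:L}) = 2^{\,-\frac{1}{L} \sum_{i=1}^{L} \log_2 p(x_i \mid x_{1:i-1})}
```

which gives PPL = 1 for a perfect model and PPL = |V| for uniform guessing over a vocabulary V.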
And perplexity really improved. On a standard dataset, between 2017 and 2023, it went from around 70 tokens to fewer than 10 tokens over those five or six years. That means models previously hesitated between about 70 words every time they generated a word, and now they hesitate between fewer than 10 words. So that's much better.

Perplexity is actually not used anymore in academic benchmarking, mostly because it depends on the tokenizer that you use and on the actual data that people evaluate on. But it's still very important for the development of LLMs: when you actually train your own LLM, people will still really look at perplexity.

Another common way, now more common in academia, of evaluating these LLMs is just to take all the classical NLP benchmarks (I'll give you a few examples later) and aggregate everything. So: collect as many automatically evaluatable benchmarks as you can and evaluate across all of them. One such benchmark, or actually two such benchmarks, are HELM, which is from Stanford, and the Hugging Face Open LLM Leaderboard, which are probably the two most common ones right now.

Just to give you an idea, in HELM you have all these types of tasks, mostly things that can be easily evaluated, like question answering. So think about many different question-answering tasks. The benefit of question answering is that you usually know the real answer. The way you evaluate these models, and I'll give you a concrete example in one second, is that you can just look at how likely the language model is to generate the real answer compared to some other answers. That's essentially, at a high level, how you evaluate these models.

To give you a specific example, MMLU is probably the most common academic benchmark for LLMs. It's just a collection of many questions and answers in all of those domains, for example college medicine, college physics, astronomy, and these types of topics. And the questions are things like (this one is in astronomy): What is true for a type-Ia supernova? Then you give four different potential answers, and you just ask the model which one is more likely. There are many different ways of doing it: either you can look at the likelihood of generating each of these answers, or you can ask the model which one is most likely. So there are different ways you can prompt the model, but at a high level, you know which one is correct.
And the three others are wrong. Yes?

For models generating unconstrained text as output: how do you evaluate a model if it gives something that's semantically completely identical, but is not the exact tokens that you expect?

Yeah, that's a great question. I'll talk more about that later. Here, in this case, we don't do unconstrained generation. The way you evaluate MMLU is basically: either you ask the question and then look at the likelihood of the model generating answer A, the likelihood of generating B, C, and D, and you take whichever is most likely. Or you ask the model which of A, B, C, or D is most likely, and you look at whether the most likely next token is A, B, C, or D. So you constrain the model: it can only answer these four things.

You constrain the prompt, or do you mean that of the whole probability distribution it outputs, you're only comparing the outputs of the A token, the [INAUDIBLE]?

Yeah. In the second case I gave you, you would actually do both: you would prompt the model with "A, B, C, or D", plus you would constrain it to only look at those four tokens. In the first case, you don't even need to generate anything. There, given that a language model gives a distribution over sentences, you literally just look at the likelihood of generating all the words of the first choice, the likelihood of generating the second choice, and so on, and you check whether the most likely full answer is actually the real one. So you don't actually sample from the model; you really just use p(x1, ..., xL). Does that make sense?
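As a sketch of that first scheme (scoring each full answer by its likelihood, never sampling), assuming a Hugging Face causal LM; the model name, the prompt formatting, and the answer options are illustrative, not how any particular leaderboard does it:

```python
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

# Any causal LM works here; gpt2 is just a small stand-in.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

def answer_logprob(question: str, answer: str) -> float:
    """Sum of log p(answer token | question + previous answer tokens)."""
    q_ids = tokenizer.encode(question)
    a_ids = tokenizer.encode(" " + answer)
    ids = torch.tensor([q_ids + a_ids])
    with torch.no_grad():
        logits = model(ids).logits              # shape: (1, seq_len, vocab_size)
    logprobs = F.log_softmax(logits, dim=-1)
    # The token at position i is predicted by the logits at position i - 1.
    return sum(
        logprobs[0, pos - 1, tok].item()
        for pos, tok in enumerate(a_ids, start=len(q_ids))
    )

question = "What is true for a type-Ia supernova?"
choices = [  # illustrative options, not the actual MMLU answer set
    "This type occurs in binary systems.",
    "This type occurs in young galaxies.",
]
scores = {c: answer_logprob(question, c) for c in choices}
print(max(scores, key=scores.get))              # the highest-likelihood choice wins
```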
That being said, evaluation of open-ended questions is something we're going to talk about later, and it's actually really important and really challenging. Yes?

Earlier you mentioned that metrics like perplexity are not usually used because they depend on how you do your tokenization, some design choices. I was wondering if you could speak more to that.

Yeah. So think about perplexity. I told you perplexity is between 1 and the vocabulary size. Now imagine that ChatGPT uses a tokenizer that has 10,000 tokens, but Gemini from Google uses a tokenizer that has 100,000 potential tokens. Then the upper bound of the perplexity that you can get is actually worse for Gemini than for ChatGPT. Does that make sense? It's actually a little more complicated than that, but that's one simple case where you can see that the tokenizer actually matters.

Great. OK, so, evaluation challenges. There are many; I'll just talk about two really briefly. One: as I told you, there are two ways of doing this evaluation for MMLU (actually, there are many more than two, but I gave you two examples). And it happens that for a long time, even though this was a very classical benchmark that everyone used, different companies and different organizations were actually using different ways of evaluating MMLU. As a result, you get completely different results. For example, Llama-65B, the first model in Meta's Llama series, had 63.7 accuracy on HELM but 48.8 on this other benchmark. So it really depends on the way you evaluate, and this isn't even about prompting; this is just the way you evaluate the models. Prompting is another issue. So really, there are a lot of inconsistencies; it's not as easy as it looks. That's the first thing.

Yeah, sorry?

How can we make sure that all these models aren't trained on the benchmark?

This is a great question, and it's the second thing: train-test contamination.
This is something that I would say is really important in academia. Given that the talk is mostly about training large language models: for companies, it's maybe not that important, because they know what they trained on. For us, we have no idea, so for us it's a real problem. There are many different ways of trying to test whether the test set was actually in the training set. One kind of cute trick that people in [? Tatsuo's ?] lab have found, sketched below, is this: given that most of the datasets online are not randomized, and that language models just predict the next word, you can take the entire test set and compare how likely the model is to generate all the examples in their canonical order versus in a different order. If it's more likely to generate them in order, given that there's no real signal in the order, then the test set was probably in the training set. Does that make sense? That's just one method; there are many other ways of doing it. Again: train-test contamination is not that important for development, but really important for academic benchmarking.
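A minimal sketch of that ordering test, assuming a `sequence_logprob` helper like the scoring function above; the single shuffle here is an illustrative simplification (the actual method compares against many permutations and does a proper statistical test):

```python
import random

def contamination_score(examples: list[str], sequence_logprob) -> float:
    """Compare log-likelihood of the test set in canonical order vs. shuffled order.

    A clearly positive gap suggests the model saw the benchmark in this exact
    order during training, since the ordering carries no real signal.
    """
    canonical = " ".join(examples)
    shuffled_examples = examples[:]
    random.shuffle(shuffled_examples)
    shuffled = " ".join(shuffled_examples)
    return sequence_logprob(canonical) - sequence_logprob(shuffled)
```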
There are many other challenges, but I'll move on for now.

Great. So, data. Data is another really big topic. At a high level, people just say you train large language models on all of the internet. What does that even mean? Sometimes people say "the clean internet," which is even less well defined. The internet is very dirty, and really not representative of what we want in practice. If I downloaded a random website right now, you would be shocked at what's in there; it's definitely not your Wikipedia. So I'll go really briefly through what people do. I can answer some questions, but data on its own is a huge topic.

Basically, the first thing you do is download all of the internet. What that means is that you use web crawlers that go to every web page on the internet, or at least every web page that is on Google, and that is around 250 billion pages right now, around one petabyte of data. People don't usually write their own web crawlers; they use standard ones, and Common Crawl is one of them. Every month it adds all the new websites that were added to the internet and found by Google, and puts them into one big dataset. So on Common Crawl you have around 250 billion pages right now, about 1e6 gigabytes of data.

Once you have this... so here's a random web page, literally random, from Common Crawl. What you see is that, one, it really doesn't look like the type of thing you would usually read; this is an HTML page. It's hard to see, but if you look through it, you'll find some content. For example, here: "Test King World is your ultimate source for the system x high performance server..." and then there are three dots, so the sentence isn't even finished. That's what the random internet looks like. Of course, it's not that useful to just train a large language model to generate things like this. So what are some of the steps that are needed?

First, you extract the text from the HTML. That's what I just tried to do by looking at the correct tags. There are a lot of challenges in this. For example, extracting math is actually very complicated, but pretty important for training large language models. Or, for example, boilerplate: a lot of forums will have the same headers and the same footers, and you don't want to repeat all of that in your data.
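As a toy illustration of that extraction step, here is a sketch using BeautifulSoup; real pipelines use much more careful extractors (handling math, boilerplate detection, and so on), so treat the tag list as an assumption, not a recipe:

```python
from bs4 import BeautifulSoup

def extract_text(html: str) -> str:
    """Keep visible prose from an HTML page, dropping markup and obvious non-content."""
    soup = BeautifulSoup(html, "html.parser")
    # Drop tags that rarely contain prose worth training on (an illustrative list).
    for tag in soup(["script", "style", "nav", "header", "footer"]):
        tag.decompose()
    # Collapse the remaining text into whitespace-normalized, non-empty lines.
    lines = (line.strip() for line in soup.get_text("\n").splitlines())
    return "\n".join(line for line in lines if line)
```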
Then you filter undesirable content: not safe for work, harmful content, PII. Usually every company has a blacklist of websites that they don't want to train their models on. That blacklist is very long, and you basically say: if it comes from there, we don't train on it. There are other ways of doing these things; for instance, you can train a small model to classify what is PII and remove those things. It's hard. Every point here that I'm showing you is a huge amount of work, but I'm just going to go quickly through it. So: filter undesirable content.

The next step is deduplication. As I said, you might have things like headers and footers in forums that are always the same; you want to remove that. Another thing you might have is a lot of URLs that are different but actually show the same website. And you might also have a lot of paragraphs from common books that are duplicated 1,000 or 10,000 times across the internet, so you have to deduplicate. This is also very challenging, because you have to do it at scale.

Once you've done the deduplication, you do some heuristic filtering to try to remove low-quality documents. The way you do that is with rules-based filters, as in the sketch below. For example, if you see outlier tokens, meaning the distribution of tokens on the website is very different from the usual distribution, then it's probably an outlier. If the words on the website are super long, something strange is going on there. If the website has only three words, is it worth training on? Maybe not. If it has 10 million words, maybe something is also wrong with that page. So: a lot of rules like this.
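A sketch of such rules-based filtering; the thresholds here are made up for illustration, since every pipeline tunes its own:

```python
def passes_heuristics(doc: str) -> bool:
    """Cheap rules-based quality filter; all thresholds are illustrative, not standard values."""
    words = doc.split()
    if len(words) < 50 or len(words) > 100_000:   # too short, or suspiciously long, page
        return False
    mean_word_len = sum(len(w) for w in words) / len(words)
    if mean_word_len > 15:                        # absurdly long "words": likely not prose
        return False
    alpha_frac = sum(c.isalpha() for c in doc) / max(len(doc), 1)
    if alpha_frac < 0.6:                          # outlier character distribution
        return False
    return True

# kept = [doc for doc in crawl_docs if passes_heuristics(doc)]  # run at full-crawl scale
```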
Yes?

Why do we filter undesirable content out of the dataset instead of putting it in with, say, a supervised loss? Can we not just say, here's this hate-speech website, and actively penalize the model for generating it?

We'll do exactly that, but not at this step. That's where post-training comes in. In pretraining, the idea is just to model, kind of, how humans speak, and to remove all these headers, footers, menus, and things like that. But it's a very good idea you just had, and that's exactly what we'll do later.

Next step: model-based filtering. Once you've filtered a lot of data, there's actually a very cute trick. You take all of Wikipedia and you look at all the links that are referenced from Wikipedia pages, because if something is referenced by Wikipedia, it's probably a high-quality website. And you train a classifier to predict whether a document comes from one of these Wikipedia references or from the random web, and you basically say: I want more of the things that look like Wikipedia references. Does that make sense? So you train a machine learning model, usually a very simple one, because you need to run it truly at scale; just think about those 250 billion pages.
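Here is a rough sketch of that classifier idea, assuming scikit-learn and placeholder document lists; real systems use fastText-style linear models over hashed n-grams at a much larger scale:

```python
# Sketch of model-based quality filtering: a cheap linear classifier that
# separates "Wikipedia-referenced" pages (positives) from random crawl pages
# (negatives). The documents here are placeholders, not a real dataset.
from sklearn.feature_extraction.text import HashingVectorizer
from sklearn.linear_model import SGDClassifier

wiki_refs = ["peer reviewed study of ...", "official statistics on ..."]  # positives
random_crawl = ["BUY NOW best deals !!!", "click here to win ..."]        # negatives

vec = HashingVectorizer(n_features=2**20, alternate_sign=False)
X = vec.transform(wiki_refs + random_crawl)
y = [1] * len(wiki_refs) + [0] * len(random_crawl)

clf = SGDClassifier(loss="log_loss").fit(X, y)  # cheap enough to run at scale

def quality_score(doc: str) -> float:
    return clf.predict_proba(vec.transform([doc]))[0, 1]

print(quality_score("a randomly crawled page about ..."))
```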
Next, you classify your data into different domains: this is entertainment, this is books, this is code, these types of domains. Then you either up-weight or down-weight some of them. For example, you might see that if you train more on code, your model actually becomes better at reasoning. That's something people usually claim in a very hand-wavy way: training more on code helps reasoning. So you up-weight the code distribution, because it helps general language modeling skills. Books is usually another one that people up-weight; entertainment they usually down-weight. Things like this. People used to do it somewhat heuristically; now there are entire pipelines, which we'll talk about, for doing it slightly more automatically.
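As a toy illustration of that up- and down-weighting, here is a sketch where each training document's domain is sampled according to chosen weights; the weights are invented, not anyone's real mixture:

```python
import random

# Sketch of domain mixture reweighting via sampling probabilities.
# The weights are made up for illustration.
weights = {"code": 2.0, "books": 1.5, "web": 1.0, "entertainment": 0.3}

def sample_domain(rng: random.Random) -> str:
    domains = list(weights)
    return rng.choices(domains, weights=[weights[d] for d in domains], k=1)[0]

rng = random.Random(0)
print([sample_domain(rng) for _ in range(5)])  # draw the domain of each batch
```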
And then, at the end of training, after training on all of the data we just saw, you usually train on very high-quality data while you decrease your learning rate. That basically means you are, kind of, overfitting your model on very high-quality data. Usually what you use there is Wikipedia: you basically overfit on Wikipedia and on human data that was collected.

The other thing is continual pretraining for getting longer contexts. I'm going to skip over these things, but that's just to give you a sense of how hard this is. When people just say, "I'm going to train on the internet," that's a lot of work. And, really, we haven't figured it out yet. Collecting data well is a huge part of practical large language modeling. Some might say it's actually the key.

Yes?

A basic question about data. You usually start with around a petabyte of data; after you go through all the steps, what's the typical amount of data remaining? And how large a team does it typically take to go through all the data steps you talked about?

Sorry, is your question how large the data is after you filter?

Yeah, after you filter and go through all the steps. And how large a team do you need for all the filtration steps you mentioned? How many people would you need to be able to do this [INAUDIBLE]?

OK, that's a great question. I'll answer the data part, how large the dataset ends up, at the end of this slide. For the number of people who work on it, that's a good question. I'm actually not quite sure, but I would say it's probably even bigger than the number of people who work on the tuning of the pretraining of the model. So the data side is bigger than the modeling side. I don't have a great sense, but on Llama's team, which has 70-ish people, I would say maybe 15 work on data. For all of this, you don't need that many people, but you do need a lot of compute, because for data you need a lot of CPUs.

So, as I just alluded to, we really haven't solved data for pretraining; there's a lot of research still to be done. First, how do you process all of this super efficiently? Second, how do you balance all of these different domains? Can you do synthetic data generation? That's actually a big one right now, because, as we'll discuss later, we don't have enough data on the internet. Can you use multimodal data instead of just text, and how does that improve even your text performance?

There's also a lot of secrecy, because this is really the key to pretraining large language models. For competitive reasons, these companies usually don't talk about how they do data collection. And there's a copyright liability issue: they definitely don't want to tell you that they trained on books, even though they did, because otherwise they can be sued.

Common academic benchmarks, which will partly answer your question. These are the smaller ones; the names are not that important, but they started at around 150 billion tokens, which is around 800 gigabytes of data.
And now it's around 15 trillion tokens, which is also roughly what the best models right now are trained on. So 15 trillion tokens, which is, I guess, two orders of magnitude bigger than that: about 80e3 gigabytes. That would correspond to filtering Common Crawl down by something like a factor of 100 to 1,000, if I'm not mistaken.

One very famous academic dataset is the Pile. We can just look at the distribution of data it contains: things like arXiv, PubMed Central, which is all the biology material, Wikipedia, Stack Exchange, some GitHub, some books, and things like this. Again, this is on the smaller side: the Pile is around 280 billion tokens, and in reality the frontier datasets are around 100 times bigger, so you cannot keep that large a share of GitHub and Wikipedia.

In terms of closed-source models, just to give you an idea: Llama 2 was trained on 2 trillion tokens, and Llama 3 on 15 trillion tokens, which is currently the best model whose training-data size we know, and the same size as the biggest academic dataset. For GPT-4 we don't really know, but it's probably the same order of magnitude; from the leaks, if the leaks are true, it's probably around 13 trillion.

Great. So, scaling laws. Any other questions on data before we go to scaling laws? Sorry, I know I'm giving you a lot of information, but there's a lot that goes into training large language models.

Great, scaling laws. The idea is that what people had seen for a long time, but were able to show empirically around 2020, is that the more data you train your models on, and the larger the models, the better the performance. This is actually pretty different from what you've seen in this class. In this class we teach you about overfitting.
Overfitting doesn't happen with large language models: larger models, better performance. It's something that really took a long time for the community, who took this type of class, to realize. But for the exam, overfitting exists.

OK, so the idea of scaling laws: given that more data and larger models will always give you better performance, can we predict how much better your performance will be if you increase the amount of data and the size of your model? And surprisingly, it works.

Here you see three plots from a very famous OpenAI paper called Scaling Laws. On the x-axis you see compute, that is, how much compute you spent for training. And you see test loss, which is essentially your validation loss, the log of the perplexity. If you put these two on a log scale, you see that the scaling law is linear. That means that if you increase your compute by a certain amount, you can say by how much your test loss will decrease. Same thing with data, and same thing with parameters: if you increase the dataset size, your loss will decrease by an amount that is somewhat predictable, and if you increase the number of parameters, the loss will decrease by an amount that is somewhat predictable.

This is really amazing, and very surprising. It looks innocuous when you look at these types of plots, but it's crazy, because it means you can predict how well we're going to perform in two or three years, depending on how much compute we add, assuming these trends hold. There's nothing theoretical about it.
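To see how simple the fitting actually is, here is a sketch with invented (compute, loss) points; the only real content is that a power law is a straight line in log-log space:

```python
import numpy as np

# Sketch: fit a power law  loss ~ a * compute^(-b)  to a few measurements.
# The numbers below are made up for illustration.
compute = np.array([1e17, 1e18, 1e19, 1e20])   # training FLOPs
loss    = np.array([4.0, 3.3, 2.75, 2.3])      # test loss at each budget

# Linear fit in log-log space.
slope, intercept = np.polyfit(np.log10(compute), np.log10(loss), deg=1)

def predict(c):
    return 10 ** (intercept + slope * np.log10(c))

print(predict(1e21))   # extrapolate one order of magnitude beyond the fits
```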
Yes?

Two things. One, what is the loss they're using here; is this perplexity?

So, as I said, perplexity is 2 to the power of the loss; the loss plotted here is the log of the perplexity.

And the second thing is, when you increase the number of parameters, or you increase the dataset size [INAUDIBLE] times, doesn't that just inherently increase your compute? Does all of this [INAUDIBLE] come down to just how much compute you use, or something specific [INAUDIBLE]?

No, this is a great question. The compute here is actually a product of two things, the data and the parameters. And, well, we're going to talk about that in detail, but basically, if you increase the number of parameters, you should increase the amount of data you have. You actually don't go multiple times through the same dataset; no one does epochs, at least not yet, because so far we still have enough data. So yeah, this is all the same trend: increase compute, decrease loss.

Yes?

Have we seen the numbers for the last two years, or is this still holding?

It is still holding. I don't have good numbers to show you, but it is still holding, surprisingly.

Yes?

Is there no evidence that the curve will ever plateau? In theory, would we expect it to plateau [INAUDIBLE]?

No empirical evidence of plateauing anytime soon. Why? We don't know. Will it happen? Probably. Although it doesn't have to, because this is in log scale, so it's not as if it had to plateau; mathematically, it could continue decreasing like this. Most people think it will probably plateau at some point; we don't know when.

So, I'll talk more about scaling laws now. Why are scaling laws really cool? Imagine that you're very fortunate and I gave you 10,000 GPUs for this month. What model would you train?
How do you even go about answering that question? This is a hypothetical, but it's exactly what these companies are faced with.

The old pipeline was basically to tune hyperparameters on the big models. Say I have 30 days of compute: I train 30 models for one day each, pick the best one, and that becomes the final model I use in production. That means the model I actually used was only trained for one day.

The new pipeline is to first find a scaling recipe, something that tells you, for example, that if you increase the size of your model, you should decrease your learning rate; that's one common recipe. So you find a scaling recipe such that you know: if I increase the size of my model, here's what I should do with my hyperparameters. Then you tune the hyperparameters on smaller models of different sizes. Say, for three of my 30 days, I train many different small models, each of a different size, and do hyperparameter tuning on them. Then I fit a scaling law and try to extrapolate, from these smaller models, which one will be the best if I train it as a much larger model. And then I train the final, huge model for the remaining 27 days instead of just one day.

So the new pipeline is: don't do hyperparameter tuning at the real scale of the model you're going to use in practice; do it on smaller models at different scales, and try to predict how well they will perform once you make them bigger.
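As a sketch of what finding one piece of a scaling recipe can look like, here is the learning-rate example: tune it on small runs, fit the trend in log-log space, and extrapolate to the big run. All the numbers are invented:

```python
import numpy as np

# Sketch of a scaling recipe for one hyperparameter: fit lr(model_size)
# on small runs and extrapolate. Every number here is made up.
model_sizes = np.array([1e8, 3e8, 1e9, 3e9])       # params of the small runs
best_lrs    = np.array([6e-4, 4e-4, 3e-4, 2e-4])   # best LR found by sweeps

slope, intercept = np.polyfit(np.log10(model_sizes), np.log10(best_lrs), deg=1)

def lr_for(n_params):
    return 10 ** (intercept + slope * np.log10(n_params))

print(lr_for(4e11))   # predicted learning rate for a ~400B-parameter run
```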
I'll give you a very concrete example right now: transformers versus LSTMs. Say you have these 10,000 GPUs and you're not sure which one you should use, a transformer-based model or an LSTM-based model. What I do is train transformers at different scales. Here you see different parameter counts on the x-axis; the y-axis is my test loss. Then I train different LSTMs at different scales. Once I have these points, I see that they roughly fit a scaling law. I fit my scaling law, and then I can predict: if I had 10 times more compute, here's how well the LSTM would perform. It's actually slightly less linear for the LSTM, but you can still try to predict where you would end up. And clearly, from this plot, you would see that transformers are better.

One thing to notice when you read these types of scaling laws is that two things are important. One is the scaling rate, which is the slope of the scaling law. The other is the intercept: a model family could start off worse but actually become better over time. It just happens that LSTMs are worse on both. But I could show you other cases where you can predict that, after a certain scale, you're better off using one type of model than another. So that's why scaling laws are really useful.

Any questions on that? Yeah.

How sensitive are these to small differences in the architecture? Like, one transformer architecture versus another transformer architecture. Do you have to fit your own curve and basically say, scaling laws tell me this should be some logarithmic function, let me extrapolate that for my own specific architecture?

Yeah. So usually, for example, if you're an academic, and at least that's pretty recent, and you want to propose a new activation, that's exactly what you do. You fit a scaling law, show it against the scaling law of the standard choice, say GELU, and you argue that yours is better.
In reality, once you start thinking about it in scaling-law terms, you realize that all the small, minor architecture differences we can make mostly just change the intercept a little bit. And that really doesn't matter much, because you can just train for 10 hours longer, or wait for the next generation of GPUs; these things are really secondary. Which is exactly why I was telling you at the beginning that people spend too much time on the architecture and the losses. In reality, those don't matter as much. Data, though: if you use good data, you will have much better scaling laws than if you use bad data. That really matters.

Another really cool thing you can do with scaling laws is ask how to optimally allocate training resources. Should I train larger models? We saw that it's better when you train larger models, but we also saw that it's better when you use more data. So which one should I do: train a smaller model on more data, or a larger model on less data?

Chinchilla is a very famous paper that first showed this, and I want to give you a little bit of a sense of what these plots are. On the y-axis you see training loss, and on the x-axis the number of parameters, that is, the size of the model. All of these curves are what we call IsoFLOP curves, meaning that all the models on one curve were trained with the same total amount of compute. The way you do that is you vary the number of tokens you train on and the size of the model, but you vary them in such a way that the total compute stays constant. So all these different-colored curves were trained with different amounts of compute. Then you take the best model on each of those curves.
Once you have the best model on each of those curves, you can plot, for each of these best points, how many FLOPs it was trained with and how many parameters it actually used. You put that on a log-log scale again, and you fit a scaling law again. So now I have something that tells me: if I want to train a model with 10 to the power of 23 FLOPs, here is the number of parameters I should use, on the order of 100B. And you can do the same thing with FLOPs and tokens. So now, if you tell me exactly how much compute I have, say one month of compute, I fit the scaling law and I can tell you what size of model to train.

Of course, that all looks beautiful. In reality there are a lot of small subtleties, like whether you should count the embedding parameters; there are a lot of complexities. But if you do things well, these predictions actually do hold.

The optimal ratio that the Chinchilla paper found is to use 20 tokens for every parameter. So if you add one more parameter, you should train your model on 20 more tokens.

One caveat: this is optimal for training resources. It tells you, if I have 10 to the power of 23 FLOPs, or say $5 million, to train the model that gets the lowest loss, what should I train? In reality, these companies also need to think about inference. A smaller model means you spend less over time. So if you take inference cost into account, there are other papers showing that the ratio is more like 150 tokens per parameter, because you prefer a smaller model: over time you're going to spend less money on inference. And 150 to 1 is around what the best models, at least the ones used in practice in production, are trained at right now.

Great.
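Putting the two rules of thumb together, compute C of roughly 6PD and the training-optimal ratio D of roughly 20P, gives a tiny calculator. This is a sketch of the rule of thumb only, not the paper's actual fitted law, so treat the outputs as order-of-magnitude numbers:

```python
# Sketch of the Chinchilla rule of thumb: with C ~ 6*P*D and D ~ 20*P,
# the training-optimal sizes follow directly from the compute budget C.
def chinchilla_optimal(C_flops: float, tokens_per_param: float = 20.0):
    P = (C_flops / (6 * tokens_per_param)) ** 0.5  # parameters
    D = tokens_per_param * P                       # training tokens
    return P, D

P, D = chinchilla_optimal(1e23)
# ~3e10 params, ~6e11 tokens; order of magnitude only, since the paper's
# fitted constants differ somewhat from this back-of-the-envelope version.
print(f"params ~ {P:.1e}, tokens ~ {D:.1e}")
```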
Any questions on Chinchilla?

Great. Oh, sorry.

In practice, how expensive is inference for these models relative to training?

Actually, very expensive. I won't talk about inference, because that would be another entire lecture. But just think about ChatGPT, which has, I don't know how many it is now, something like 600 million people using it. That's a lot. So it's actually very expensive. There's a lot of optimization you can do for inference, though, and that's an entire other lecture; I'm going to skip it this time, but it's very interesting.

OK. As I said, there are many things you can answer with scaling laws. I just gave you two examples, but there really are many. What data do you use; what data-mixing weights do you use, the mixtures we talked about before; what architecture you use; whether you should make your models wider or deeper; whether you should pay for more GPUs or collect more data. All of these are things you can try to answer with scaling laws.

One thing I want to mention is the bitter lesson. If you've ever heard of Richard Sutton: in a very famous blog post in 2019, what he realized, which I think not enough people realized, and I definitely did not realize at the time, is that once you see these types of scaling laws, you know that the more compute you have, the better the models you will get. So with scale, you get better models. And you also know, by Moore's law or its variants, that you will always have more compute. Then the only thing that matters is to have architectures that can leverage computation. So what matters is basically systems and data, and much less the small architecture differences, like your activation function and things like this.
So I think that's one of the reasons why most research focuses on things that matter less for industry. And I was one of those researchers for a large part of my career. So don't spend time overcomplicating: do the simple thing, do it well, and scale it. That's really what OpenAI taught us with ChatGPT and with all the GPTs before it.

OK, I want to give you some back-of-the-envelope computations. I might be off by a few factors here, but I just want to give you a sense of how costly it is to train some of these models. As an example, take Llama 3 405B, which is currently the best open-source model you can get. It was trained on 15.6 trillion tokens and has 405 billion parameters. Now that you know about the optimal tokens-per-parameter ratio: that's around 40 here, so a bit more than Chinchilla-optimal, but less than the inference-optimal ratio. So they went for something close to training optimality.

FLOPs for this model: one simple way to approximate training FLOPs is 6 times the number of parameters times the number of tokens you train on, so roughly 6PD, where P is parameters and D is data, the number of tokens. This is just an approximation. If you do the simple calculation here, you get 3.8e25 FLOPs. The reason this number is interesting is that, if you follow the news a little, there's an executive order from Biden that basically says that once you reach 1e26 FLOPs, your model gets special scrutiny. So they went roughly 2x below that; they really went right under the threshold to avoid special scrutiny. So 3.8e25; I might be off by a little bit, but it's definitely under 1e26.

OK, compute. We know they trained on 16,000 H100s, and we know the throughput they reported. If you do the computation, it takes around 70 days, or 26 million GPU-hours; at least, that's what my back-of-the-envelope computation gives.
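Here is that computation written out. The ~400 TFLOPs-per-GPU sustained throughput is roughly the reported figure for the run; everything else follows from the 6PD approximation:

```python
# Back-of-the-envelope for Llama 3 405B, using C ~ 6 * P * D.
P = 405e9                 # parameters
D = 15.6e12               # training tokens
C = 6 * P * D             # ~3.8e25 FLOPs

n_gpus = 16_000
throughput = 400e12       # sustained FLOPs/s per H100, roughly as reported

seconds = C / (n_gpus * throughput)
gpu_hours = n_gpus * seconds / 3600

print(f"C = {C:.2e} FLOPs")                   # ~3.79e25
print(f"~{seconds / 86400:.0f} days")         # ~69
print(f"~{gpu_hours / 1e6:.0f}M GPU-hours")   # ~26
```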
They actually said that they used 30 million GPU-hours instead of my 26 million, so maybe they had some challenges; I don't really know. But if you follow the simple computation, it's around 70 days.

Cost. This is hard to approximate, but I'm just going to treat it as rent: if I wanted to rent that many H100s for that many days, how much would I pay? A lower bound on the renting cost of an H100 is around $2 per hour. If you multiply that by 26 million hours, you get $52 million. They probably paid less than that, but not actually much less, because the services that rent GPUs don't make that much money. So probably slightly less, but not much.

Now, salaries: say 50 employees at $500k per year. That's probably the right ballpark, so $25 million. If you put it all together, it's around $75 million to train this Llama model. I'm probably off by something like $10 million, but that's the right ballpark.

Carbon emitted. A lot of people might ask about this, since cost is not the only thing that matters. So I did the computation: it's around 4,000 tons of CO2 equivalent. That is actually only about 2,000 return tickets from JFK to London. So right now, the carbon emitted is, I mean, it's huge, but it's not the dominant concern yet. Maybe by GPT-6 or GPT-7, once you multiply this by 100, it might become a real issue. Right now it's still not, I think, an issue in the grand scheme of things.

For the next models, the way you should think about it is that every new generation multiplies the number of FLOPs by roughly 10x, or at least that's what they aim for, if they have enough energy and can buy enough GPUs.

Great. Any questions on this back-of-the-envelope math? No.
OK. So now that we've talked about pretraining, I also wanted to chat about systems, because now we know compute is really important, so there's a question of how you optimize the compute. I'll leave that for the end, because I'm not sure how much time we'll have. I think it's important, but hopefully I'll be able to talk about it later; it's slightly different from what we've been discussing so far. So I'll move on to post-training for now.

So, the task of post-training. The reason we need post-training, as I told you before, is to make AI assistants. Language modeling is not really what you want from an AI assistant. For example, if you give GPT-3, which is a pure language model, not an aligned one, the prompt "explain the moon landing to a six-year-old," the completion you get is something like "explain the theory of gravity to a six-year-old." Because what it learned is that on the internet, one question is usually followed by a bullet point of other, similar questions; you don't usually get a question and then its answer. This is not what you want from an AI assistant.

So how do we do this alignment, this post-training that makes these models assistants? The goal of alignment is basically to get LLMs to follow the instructions given by users, and also, kind of, the designers' desires. Think about moderation: OpenAI doesn't want the model to say things that are very toxic. Here you see on the left-hand side that when you ask a question, the aligned model actually provides a real answer, unlike the raw LLM before. And on the right-hand side, you see that if you ask it to write a tweet describing how a certain part of the population is evil, it will say that it cannot do that. So that's this alignment.
The background here is that we basically know what data we want for training these models: just ask humans, this is a question, this is the answer that you want. But the thing is that it's very expensive to collect that data, and it's hard to find online. In contrast, pretraining data is not what you want, but there's a lot of it. So the main idea is simply to take a large language model pretrained on all of the internet and then fine-tune it: you just change the weights a little bit on the type of data you actually want. And hopefully, given that it was pretrained on all of the internet, it basically already knows how to speak English and knows standard language syntax, so you can fine-tune it with very little data.

OK, SFT. Supervised fine-tuning is exactly what I just said: the idea of fine-tuning the large language model on the desired answers collected from humans. Why is it called supervised fine-tuning? Because you do language modeling, that is, next-word prediction, on the desired answers; that's the fine-tuning part. And you do it on desired answers given by humans, which is why we call it supervised.

So how do we collect this data? Well, I just said it: you ask humans to write, this is a question, and this is the answer you would want from the model. Here's an example. I can't read it very well on my computer, but let's read this one: "Can you write a short introduction about the relevance of the term monopsony?" And the answer: "Monopsony refers to a market structure," blah, blah, blah, and that's a human who wrote it. This is actually from Open Assistant, which was an effort to collect this kind of data online from humans.
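Mechanically, SFT is just the pretraining next-token loss restricted to the answer tokens. Here is a minimal PyTorch sketch, with toy token IDs and random logits standing in for the model:

```python
import torch
import torch.nn.functional as F

# Sketch of the SFT objective: the same next-token cross-entropy as
# pretraining, computed only on the answer tokens. Toy IDs throughout.
prompt = torch.tensor([5, 8, 2])      # "question" tokens
answer = torch.tensor([7, 9, 4, 1])   # desired-answer tokens

full = torch.cat([prompt, answer])
inputs, targets = full[:-1], full[1:].clone()
targets[: len(prompt) - 1] = -100     # mask positions that predict prompt tokens

vocab_size = 16
logits = torch.randn(len(inputs), vocab_size)  # stand-in for model(inputs)
loss = F.cross_entropy(logits, targets, ignore_index=-100)
print(loss)
```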
This type of supervised fine-tuning, or alignment, is really the key to ChatGPT. It's what made the big jump from GPT-3, which was mostly known by AI researchers, to ChatGPT, which became known by basically everyone.

The problem with human data is that it's very slow and very expensive to collect. So one simple idea is to use LLMs to scale data collection. That's exactly what we did with Alpaca a year ago. We took a dataset of 175 human-written question-answer pairs and asked the best model at the time, text-davinci-003, to generate many more of these questions and answers: here's what humans wrote, now write similar questions and similar answers. We collected 52,000 LLM-generated question-answer pairs that way. Then we simply took LLaMA 7B, which was the best pretrained model at the time, and fine-tuned it with supervised fine-tuning, as I told you. And that's how we got the Alpaca 7B model.

This is the type of data we collected. For example: "What does algorithm mean?" "An algorithm is a step-by-step set of instructions used to solve a problem or achieve a goal," blah, blah, blah. The data is actually pretty good, given that it was generated by LLMs from essentially two generations ago.

For us, that started as an academic replication of ChatGPT. Now there's a big field of synthetic data generation: how to use LLMs to make the development of LLMs faster, basically by decreasing the number of human hours you need.
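A minimal sketch of that Alpaca-style generation loop; `llm` and `seed_tasks.json` are placeholders for your teacher model and seed set, not real APIs:

```python
import json
import random

def llm(prompt: str) -> str:
    # Placeholder for the teacher model (Alpaca used text-davinci-003).
    return "Q: ...\nA: ..."

# e.g. 175 human-written {"question": ..., "answer": ...} seed pairs
seed_tasks = json.load(open("seed_tasks.json"))

def make_prompt(examples):
    shots = "\n\n".join(f"Q: {e['question']}\nA: {e['answer']}" for e in examples)
    return shots + "\n\nWrite one new, different question and answer in the same style.\nQ:"

rng = random.Random(0)
synthetic = [llm(make_prompt(rng.sample(seed_tasks, 3))) for _ in range(52_000)]
```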
Quantity of data. We talked about what type of data and how to collect it. One surprising thing about SFT is that you don't need that much data. What the LIMA paper showed is that if you scale the amount of supervised fine-tuning data from 2,000 to 32,000 examples, it really doesn't help much. So here, scaling laws definitely don't apply.

The intuition is that all you learn in SFT is how to format your desired answers. Another way of saying it: your pretrained model essentially models the distribution of every user on the internet, one who might write bullet points, another who might answer a question with an actual answer. So all you're telling your model is: wait, you should actually be optimizing for this type of user rather than another one. You're not actually teaching it anything new through SFT; you're telling the model to optimize for one type of user it has already seen in the pretraining data. The knowledge is already in the pretrained LLM, and you basically just specialize it to one type of user.

Great. Any questions on SFT? Yes.

I know there's a big issue with synthetic data, where if you keep generating data from the same distribution, eventually you're not learning a new distribution, you're essentially just bootstrapping.

Yeah.

Surely you can't scale that forever, right? You can't keep generating from the same distribution and hope to learn something new.

Yeah.

So, it's an active area of research, but do you have any thoughts on how people are thinking about this, and better ways to bootstrap? Or do we give up on the idea and realize that, as the chart shows, you don't need that many examples, so just get humans to write 2,000 really good prompts?

Yeah, that's a very good question.
On the data question: I'm saying it's not that important for SFT, but there's another thing we'll talk about right after where data does matter. My intuition, based on not that many empirical results, is that if you use purely LLM-generated text, and you do that for three or four generations of LLMs, I agree you probably won't improve much. But what matters to me is how you use humans in the loop with LLMs. Not purely LLMs, not purely humans. Maybe what you do is have the model generate some new text and have humans just write a few edits. Edits are much faster than writing the entire text. With that type of collaboration, from an information-theoretic point of view you still get additional information, but you're still much faster than if you use humans alone. I think as a field we'll probably move towards these kinds of things: finding the examples that are important and asking humans about exactly those. It's kind of active learning, asking humans exactly when you need their input.

Yes.

AUDIENCE: Do we train with the same loss function and the same general training algorithm for the supervised fine-tuning bit as we do for pretraining? Because in the examples you showed, the important thing about the good examples is that they're super factually accurate. There are these more complex things and it's still just like [INAUDIBLE].

INSTRUCTOR: Same loss. Maybe I didn't emphasize this enough: it is just language modeling. You fine-tune the LLM with the language modeling loss on the desired answers. So it's literally the same loss. It will be different in two seconds, but the first step, SFT, is literally the same loss, where you just say: OK, I want to specialize on that type of data.
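To make that concrete, here is a minimal sketch of the SFT loss in PyTorch. The function sft_loss is a hypothetical helper, it assumes a single already-tokenized prompt-plus-answer sequence, and whether you mask the prompt tokens out of the loss varies across implementations:

    import torch
    import torch.nn.functional as F

    def sft_loss(logits, tokens, prompt_len):
        # Next-token prediction: the logits at position t predict token t+1,
        # exactly as in pretraining; only the data is different.
        pred = logits[:-1]                  # (seq_len - 1, vocab_size)
        target = tokens[1:]                 # (seq_len - 1,)
        per_token = F.cross_entropy(pred, target, reduction="none")
        # Train only on the desired answer, not on the prompt itself.
        positions = torch.arange(target.numel(), device=target.device)
        answer_mask = (positions >= prompt_len - 1).float()
        return (per_token * answer_mask).sum() / answer_mask.sum()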
There's even a question of what counts as pretraining and what counts as post-training, because in reality it's just different data that you use. The reason we usually call it post-training is that the way we collect that data is very different.

Great, great questions. Yes.

AUDIENCE: Maybe it's the same question, but why would these 2,000 examples have such an overweighted influence on fine-tuning?

INSTRUCTOR: That's another reason why we call it post-training: we use a different set of hyperparameters. I told you that at the end of pretraining you essentially end up with a learning rate of 0. Here, you increase your learning rate again, to something like 1e-5. So the weight you effectively give to this data is different.

OK. The second part of post-training is what we call reinforcement learning from human feedback, or RLHF. Some of you might have heard of it. The idea is that SFT has a problem, namely that you do behavioral cloning, which means you just try to clone what humans would say. And that has many issues.

One of them is that you're bound by human abilities. Humans won't generate the things that they think are actually the best things to generate. If you ask me to write a book: I can definitely enjoy a book, and I can probably say one book is better than another, but I'm definitely not going to be as good at writing the book that I want to read. So you're bound by the human ability to generate things, even though humans might be better at distinguishing between things. That's one issue.

Issue number two, which I actually find pretty interesting: if you've ever heard the word hallucination, this is LLMs generating false information.
People have hypothesized that hallucination can come from supervised fine-tuning, even if you do supervised fine-tuning on data that is correct. The reason is this: I told you that SFT uses very little data, data from which the model doesn't learn anything new. So what if the human gives an answer that the model didn't know was true? From the model's perspective, the human is telling it to generate something that seems plausible, but the model has no idea whether it's true or not.

To give a very concrete example, go back to the monopsony example: can you write such-and-such about monopsony? Imagine the human wrote a reference to some book. That book might exist; it might be a perfectly correct reference. But what if the LLM never saw this reference during pretraining? Then it doesn't know that it's a correct reference. So what you're really telling the model is to make up some plausible-sounding reference, rather than to give the real reference it saw during pretraining. So hallucination might be caused by this SFT. That's problem number two. Does that all make sense? Great.

Problem number three: price. Generating the ideal answers is very expensive, which comes back to your question: having humans write entire answers is pretty costly.

So that's where RLHF comes in. The idea is that instead of cloning the behavior of humans, we're going to maximize human preference. The pipeline: for every instruction, you ask a model to generate two answers, and you usually use a pretty good model here, not a raw LLM but a model that already went through SFT, so it gives pretty good answers. Then you ask labelers which of these two answers is better, and they select the preferred one.
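Concretely, each labeled comparison ends up as a record roughly like this (a hypothetical format; the field names are only for illustration):

    # One RLHF preference example (hypothetical field names).
    preference_example = {
        "instruction": "Tell me about self-driving cars.",
        "chosen": "Self-driving cars are vehicles equipped with sensors ...",
        "rejected": "Self-driving cars are vehicles capable of detecting ...",
    }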
Then, with different types of algorithms, which we'll talk about now, you just fine-tune the model to generate more of the green thing than the red thing: more of the good stuff. So the question is how, and we're going to talk about that right now. There are two ways we'll cover, and they're the two mainly used in the community.

The first one is simply the idea of using reinforcement learning. Hopefully you all know what reinforcement learning is by now. When you think about using reinforcement learning, one important question is: what is the reward that we're optimizing? In this case, there are really two options I can think of.

Option one: you could just compare the output generated by some baseline with the output generated by my model, ask the human which one is better, and use that as the reward. If I'm better than the baseline, it's a +1; if not, it's a -1. So it's a binary reward. The problem with a binary reward is that it's very sparse and you don't get much information out of it. Maybe your answer was slightly better, maybe it was way better; you don't really know from this how much better it was.

Option two: you can train what we call a reward model, which is simply a classifier. You use machine learning to classify how much better one output is than the other, from the perspective of the human. This is a little bit meta, but what you do is take a reward model, which is itself a large model used as a classifier, give it the input and one of the two outputs, and have it produce a score. You exponentiate that score, which gives you the softmax loss that you all know about.
You divide the exponentiated reward of the preferred output by the sum of the exponentiated rewards of the two outputs. The reason you do this is that you train the reward model to be able to classify how much better one output is than another. A slightly less convoluted way of saying it: your reward model outputs a score that is used as the logit of your softmax. If the logit is high, it means this output is very likely the better one. That's what we call the Bradley-Terry model.

Yes.

AUDIENCE: Will this reward model [INAUDIBLE] the entire output, or is it going to [INAUDIBLE]?

INSTRUCTOR: This takes the entire output at once. It takes all of the input and all of the output, and it gives one number.

Yes.

AUDIENCE: With the reward model, where would the human be?

INSTRUCTOR: I see; sorry, maybe I wasn't clear. You train this reward model to fit the green and red preferences from humans. Basically, you train a classifier to say whether the humans prefer red or green. But instead of using the binary label, which is what the human gives you, you use the logits of the softmax. And the thing with logits is that they're continuous. So if your reward model outputs a high logit, then, in some sense, the human highly preferred this answer to the other one.

Great. So, as I just said, continuous information is better, and that's what people use in practice, or at least used to use; I'll tell you about the other algorithm later. At the end, you just apply the reinforcement learning that you know about: now we have a reward.
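Written out, the reward-model loss just described, the Bradley-Terry model, is the following, where $r_\phi$ is the reward model, $x$ the input, and $y_w$, $y_l$ the preferred and dispreferred outputs (this notation is chosen here for convenience):

    \mathcal{L}(\phi) = -\,\mathbb{E}_{(x, y_w, y_l)}\left[\log \frac{\exp(r_\phi(x, y_w))}{\exp(r_\phi(x, y_w)) + \exp(r_\phi(x, y_l))}\right] = -\,\mathbb{E}\left[\log \sigma\big(r_\phi(x, y_w) - r_\phi(x, y_l)\big)\right]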
What you sample is the generation from your large language model, and then you add a regularization term. The reason for this regularization term is to avoid what we call overoptimization: the reward model might not perfectly model human preferences, so you don't want to maximize it all the way to infinity. And you optimize this using PPO, which is a common reinforcement learning algorithm.

One thing to note here, because it will be important later: the large language model is now a policy for your reinforcement learning. It's not maximizing likelihood anymore, which means you're not modeling any distribution anymore. The reason this is important is that models that went through this type of PPO don't give you meaningful likelihoods of text. What you optimized them to do is generate the most preferred thing, not model all the answers that humans might give. Another way of saying it: nothing here incentivizes the model to keep more than a single possible generation; nothing says it's good to have a distribution with some entropy. If you haven't followed, it's not that important, but it's good to know.

Great. So PPO is exactly what ChatGPT did originally. From their blog post, what they have is: step one, do supervised fine-tuning, which you now all know about. Step two, train a reward model on human preferences. Step three, do multiple steps of PPO, which is where you see the blue arrow: you train the model once with PPO, collect new data, and continue. That's exactly what ChatGPT did, and that was the big breakthrough between GPT-3 and ChatGPT.
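In symbols, the regularized objective optimized in that PPO step is commonly written as expected reward minus a KL penalty that keeps the policy $\pi_\theta$ close to the SFT reference model $\pi_{\mathrm{ref}}$ (this is one common form; implementations differ in the details):

    \max_\theta \;\mathbb{E}_{x \sim \mathcal{D},\; y \sim \pi_\theta(\cdot \mid x)}\big[r_\phi(x, y)\big] \;-\; \beta\, \mathbb{D}_{\mathrm{KL}}\big[\pi_\theta(y \mid x)\,\big\|\,\pi_{\mathrm{ref}}(y \mid x)\big]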
One thing to note is that PPO has many challenges. Reinforcement learning is something that is super nice theoretically; in practice, anyone who has ever worked with reinforcement learning knows it's a mess. There's a lot of machinery: rollouts, outer loops, clipping, so many complications. Even the idealized PPO for the LLM setting is already much more complicated than the expectation we saw before, and in practice it's more complicated still. We have one implementation of it that we had to write, and I'm not going to go through it, but there is so much you have to think about when you implement that type of PPO algorithm: clipping everywhere, a lot of complexity, and things are not well documented.

All this to say that there's a newer method, proposed also from Stanford one year ago, called DPO, which is essentially a simplification of PPO. The idea is that instead of using reinforcement learning, you can just maximize the probability of generating the stuff that you like and minimize the probability of the stuff that you don't like. Thinking in terms of the human preferences, the red and the green: maximize green, minimize red.

The loss is this one. The term you see is simply the log-likelihood of the model generating the answer that the human preferred, given the input. What you try to do is maximize the likelihood of generating the things you like and minimize the likelihood of the things you don't like. The rest of the terms are not too important, and it's really not that complicated to understand. At a high level, it's just maximizing the things you like and minimizing the rest.
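For reference, the DPO loss takes the following form, where $\pi_\theta$ is the model being trained, $\pi_{\mathrm{ref}}$ is the reference (SFT) model, $\beta$ is a hyperparameter, and $y_w$, $y_l$ are the preferred and dispreferred answers:

    \mathcal{L}_{\mathrm{DPO}}(\theta) = -\,\mathbb{E}_{(x, y_w, y_l)}\left[\log \sigma\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}\right)\right]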
And one thing to note, which I was going to say just here, is that the rest of the loss is chosen such that the global minimum of PPO and the global minimum of DPO are, under some assumptions, essentially equivalent. So this is the right thing to do mathematically. I'm not going to go through the derivation, but it is the right thing to do. It's quite different from PPO in practice: with PPO, you collect the human preferences, then train a reward model with maximum likelihood, then use reinforcement learning. Now all you do is maximum likelihood. Much simpler. Yes.

AUDIENCE: So it seems like this is, A, much simpler and, B, what you would intuitively do with [INAUDIBLE]. Why did they start with the reward model? What led them to that?

INSTRUCTOR: I think it's a great question, and I don't really know. What I can tell you is that the people who built ChatGPT initially are the ones who actually wrote PPO, and a lot of them are reinforcement learning people, so I think for them it was very intuitive. There are also some additional potential benefits. For example, the nice thing with reinforcement learning is that you can use unlabeled data together with the reward model. With DPO you can only use the labeled data, while with PPO you first train your reward model and can then use unlabeled data, which the reward model will label for you. So there could be potential improvements; in practice it turns out there are none. I think it's just that a lot of people on that team were reinforcement learning experts, including the main author of PPO, John Schulman.

So: much simpler than PPO, and it basically performs as well.
So now this is the standard thing people use, at least in the open-source community, and I believe it's actually also the standard in industry. That's DPO.

The gains: those are all the papers on the left. This one is on a summarization task. All I want to show you is that the pretrained models were OK and improve with scale; if you do supervised fine-tuning, you improve them a little more; and if you do PPO, something with human feedback, you get performance that is, oftentimes, depending on the benchmark, even better than the human reference summaries. Same thing on a paper of ours, AlpacaFarm: the evaluation here is not too important, but you see the pretrained model, a jump to SFT, and then a jump to PPO and DPO, and PPO and DPO have the exact same performance. So basically: RLHF helps, that's the conclusion, and DPO is simple.

Data. How do you collect this type of preference data? The first idea is just to use humans, as we already talked about. The guidelines for what humans should be labeling are very complicated, and it's really not that easy. If you ever do some of the labeling yourself, you'll see that it's extremely complicated. If I zoom in on this: the question is "Tell me about self-driving cars." You read both answers: self-driving cars are vehicles that are capable of detecting their surroundings, and so on; self-driving cars are cars equipped with sensors to navigate without the need for a driver, and so on. Both seem OK. Which one is better? It's actually hard to say at a glance. As a result, the problem with humans is that they start optimizing for a lot of high-level features. For example, the second one is longer.
I can guarantee you that most humans will choose the second one, even though maybe the first one is better; I don't know, I haven't read it carefully.

So, challenges of human data. First, it's slow and expensive. Second, as I just mentioned, it's hard for annotators to focus on the things that matter, like correctness; people usually look at things that matter less, like the form, like length. As a result, and this is what I show here, the more RLHF you do, the longer the outputs of the model become. If you've ever been annoyed at ChatGPT answering with super long responses, that is because of RLHF. Third, annotator distribution shift: the distribution of annotators you use matters a lot, and you have to think about which humans we even want these models to represent. Another issue is crowdsourcing ethics: a lot of the people who do this labeling are not paid well, and they have to go through a lot of toxic data, precisely because you want the model to avoid saying toxic things. So: many challenges with human data.

So what we did, also last year, is again the same idea as Alpaca: there are challenges with humans, so maybe we can just replace them with LLMs. And, I see, I'm just realizing that the slides are not centered. Anyway: you replace human preferences with LLM preferences. In this figure, the x-axis is the price we paid for collecting human data, around $300 for 1,000 examples, and that's with Mechanical Turk workers, who are usually cheaper than some of the other companies you could go through. The y-axis is basically the agreement with other humans, with the mode of the other humans.
And what you see is that, as I told you before, labeling is really complicated. Humans agree with one another only around 66% of the time on a binary task. It's not that these humans are bad: the five main authors of this paper tried to label the data ourselves, and we only got around 67 or 68% accuracy, even though we talked for something like three hours about how we should do the labeling. It's just complicated; it's not an easy task. Here I show many different models, and you see that models are much cheaper and can actually get higher agreement with the mode of the humans than humans themselves. The reason is that humans have a lot of variance, while models have essentially none; models might be a little more biased, but they have less variance. So it works surprisingly well, and now it's kind of the standard in the open-source community. Even in industry, I think a lot of people use both humans and LLMs to improve the collection of RLHF data. This is a paper from last year, but honestly, by now LLMs would be around this level of agreement, at something like 50x lower cost than humans, with better agreement with humans than humans themselves.

OK. That gets us to the evaluation of post-training, which goes back to your question at the beginning of the lecture: how do you evaluate something like ChatGPT? The answers it could give are basically unbounded, and it's not that there's one right answer; there are many answers that are just as good. So there are many challenges. One, you can't use the validation loss, because one method might use PPO and another might use DPO; their validation losses are not comparable. Two, you can't use perplexity, which is the thing I told you before: these models are not calibrated.
They don't give you distributions anymore; they were optimized for one thing. So you can't use perplexity to evaluate these models once they're aligned. Three, there's a large diversity of questions that humans might ask these models: generation, open QA, question answering, summarization, all of these things, so there's a lot to cover. And the tasks are really open-ended, so it's very hard to automate. That's what you were alluding to before.

So the idea is that, instead of trying to come up with easily automated benchmarks, we take questions that users actually ask these models in practice, and we ask annotators to say which of two models gave the better output. It's basically the exact same setup as the RLHF preference data, but now you use it for evaluation. Yes.

AUDIENCE: I'm not sure I understand what you mean by "can't use perplexity, not calibrated." RLHF is still doing next-token prediction, so why can't perplexity be used?

INSTRUCTOR: Think about the optimal solution after doing PPO: it's basically a model that puts essentially a delta, a point mass, on one answer. It says there is only one sentence that could be generated for that question. So if you evaluate it on something that is even slightly semantically different, it would give that answer a likelihood of essentially 0. In reality it's not that extreme, because, as you say, it's still a distribution, but it shows you there's a fundamental issue with perplexity: once these models go through PPO, they're not trained to do maximum likelihood anymore; they're trained to be policies.
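To see the "delta" argument in symbols: with no entropy or KL term, the reward-maximizing policy concentrates all of its mass on a single answer,

    \pi^*(y \mid x) = \begin{cases} 1 & \text{if } y = \arg\max_{y'} r(x, y') \\ 0 & \text{otherwise,} \end{cases}

so the likelihood it assigns to any other plausible answer, and hence its perplexity on held-out text, stops being meaningful.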
Probably the most common, or at least the most trusted, benchmark is what we call Chatbot Arena. Random users on the internet blindly talk with two chatbots, ask many questions, see the two answers, and rate which one is better. You do that over hundreds of thousands of users, and you get actual preferences and rankings of models. You can go on Chatbot Arena right now and interact with these models. One potential issue, just to highlight: the people who want to do this kind of thing are usually more tech-driven, or tech-savvy, so a lot of the questions are tech stuff: discussing software errors, inquiries about AI tools, and so on. Another issue is cost and speed: if you want to use something like this for your development process, it's too costly, because you need to pay a lot of humans.

So one simple idea is, again, as we've said many times: use LLMs instead of humans. You probably know the drill at this point. The steps: for every instruction, generate outputs from some baseline and from the model you want to evaluate. Imagine I'm comparing an answer from ChatGPT and one from Mistral. I just ask another model, say GPT-4, which one is better. I average that out over my entire benchmark or dataset, and that gives me a win rate: a win probability for one model compared to another. Now you can rank models, and this is the AlpacaEval leaderboard. The benefit of this is that we show we get 98% correlation with Chatbot Arena, so very high correlation with humans. This is the comparison of correlations with other benchmarks.
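A minimal sketch of that evaluation loop might look like the following. Here judge_prefers_model is a hypothetical stand-in for a call to whatever judge LLM you use; a real implementation would also handle ties, prompt wording, and position bias:

    # Hypothetical LLM-as-judge win-rate computation.
    def judge_prefers_model(prompt: str, model_answer: str, baseline_answer: str) -> bool:
        """Stand-in for a call to a judge LLM (e.g. GPT-4) that returns
        True if the judge prefers model_answer over baseline_answer."""
        raise NotImplementedError  # call your LLM API of choice here

    def win_rate(prompts, model_answers, baseline_answers):
        wins = sum(
            judge_prefers_model(p, m, b)
            for p, m, b in zip(prompts, model_answers, baseline_answers)
        )
        # Fraction of prompts on which the judge prefers the evaluated model.
        return wins / len(prompts)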
And it takes less than three minutes and less than $10 to run, so it's pretty cheap. There are downsides, though. One of them is spurious correlations; there are many, but I'll just talk about one: LLMs prefer longer outputs. Actually, humans also prefer longer outputs, but the issue with LLMs is that once the bias is there, they will keep optimizing for it. Humans, at some point, if I ask a simple question and you give me five pages of answer, will say: no, I don't like that answer. But LLMs, if they have this bias and were trained that way, will just keep preferring longer outputs. Here we see preferences showing that both humans and models prefer longer outputs. And here is another view of the initial AlpacaEval benchmark: when we look at the win rate of GPT-4 against GPT-4 itself, standard GPT-4 gets 50%, kind of by definition, since we're comparing GPT-4 to GPT-4. But if we ask GPT-4 to be slightly more verbose, just saying "be verbose in your answers" in the prompt, it gets a win rate of 64.4%. And if we ask it to be concise, it gets 20%. So there's huge variance depending on whether you ask it to be concise or verbose, which is very annoying.

One possible solution, which is what we did, is to use some regression analysis. I'm not going to go into the details, but you basically use causal inference tools to control for length. And now length matters much less: if you ask the model to be verbose, you still get some gains, but much smaller.

Great. So that's all about post-training. For the next eight minutes, I might talk about systems, or just answer questions. Yes.
AUDIENCE: Can you go back to post-training? How did we tune those parameters using the small body of fine-tuning data and have such a big effect on the model? You mentioned earlier that there's a different set of hyperparameters. Are we changing just some of the weights, the later weights, or other weights? What's actually happening?

INSTRUCTOR: Yeah, I kind of skimmed through all of this. You change all the weights. Industry, at least, will change all the weights. In open-source land, you might have heard of LoRA, which changes only some of the weights; to be more specific, it adds a low-rank difference to the output of every layer. But in industry, you just fine-tune all the weights.

And to say something else about the data: for this last step, RLHF, you usually collect a lot more data than for SFT. If SFT is something like 5,000, 10,000, maybe 50,000 examples, with RLHF I think you're more around the one-million order of magnitude. It's still much less than pretraining, though, because pretraining is 15 trillion tokens.

AUDIENCE: I mean, that's not even a drop in the bucket, and yet you influence the weights a lot.

INSTRUCTOR: You have to think about how you do it. As I said, the learning rate you use is going to be different, but also, this is all you train on. Imagine I trained on just one sentence, but over and over again: at some point my model will only generate that sentence, even though it was one sentence instead of 15 trillion tokens. If you use a large enough learning rate, for long enough, you will basically overfit that sentence. So the key thing to remember is that it's not as if you mix some post-training data in with the pretraining data.
You do pretraining, and then you start fine-tuning only on the post-training data. Another perspective is that pretraining is just the initialization of your model. Once you view it that way, that this is just an initialization of the weights, there's nothing special: you don't need to remember that you trained on a lot of data before. The only thing that matters is that you had an initialization and now you're actually training the model. It's a Markov property, in some ways: these were my weights, this is my initialization, now I train from there. Does that answer your question?

AUDIENCE: Kind of, but you said something just now about it being almost the equivalent of rerunning the fine-tuning data many times. Is that what actually happens, in order to give it so much more weight?

INSTRUCTOR: I actually don't know right now how they do it in industry. When we did Alpaca, we did three epochs, so you did run through the data three times. But even the number of passes through the data is not what's important; the only thing that matters is the effective learning rate.

Great. So I think I have five minutes. OK, I'll try to give a high-level overview of at least one of the systems tricks. For systems, as we said, compute is the huge bottleneck for everyone. One question you might ask is: why not just buy more GPUs? GPUs are expensive, but they're also scarce: even if you have $10 million right now, you cannot buy the best GPUs. [INAUDIBLE] There are also physical limitations: when you have multiple GPUs, you have to communicate between them, and that takes time. So just buying more GPUs is not that easy.
So it's really important to think about how you allocate resources and how you optimize your pipeline: systems. A quick 101 on GPUs; sorry, I'm going slightly fast, and I hope at least some of you can follow. GPUs are optimized for throughput; CPUs are optimized for latency. The way to think about a GPU is that one command runs on many, many cores at the same time, on different pieces of data. This is how you should picture a GPU: there are many cores, grouped into what we call streaming multiprocessors, which is very different from the usual CPU architecture. So for GPUs, think high throughput and parallelization.

GPUs are optimized for fast matrix multiplication. Whenever you do something on a GPU, if you can do it as a matrix multiplication, it's going to be around 10 times faster than anything else. That's a little annoying, because it means we're somewhat locked into doing everything with matrix multiplications.

Another thing to note about GPUs is that compute has been improving faster than memory and communication. Right now it's hard to keep GPUs fed: the data you send to the GPU struggles to keep up with the processors. So most of your GPU will actually sit idle if you just run normal, unoptimized code, and this gap will keep growing over time. Another thing to know is that there's a memory hierarchy, the same as with CPUs: the closer the memory is to the cores, the smaller it is but the faster it runs; further away, there's more memory, but it's slower.

Oh, I'm going to skip that. OK, actually, I'll say it. I told you about the cost of communication. The metric people usually look at is model FLOP utilization, or MFU.
2335 01:39:34,489 --> 01:39:37,689 Model FLOP utilization, or MFU, is the observed throughput,
2336 01:39:37,689 --> 01:39:39,879 in floating-point operations per second,
2337 01:39:39,880 --> 01:39:42,730 divided by the theoretical maximum throughput
2338 01:39:42,729 --> 01:39:45,949 that the GPU could sustain.
2339 01:39:45,949 --> 01:39:49,399 And in general, if you reach 50%, you're very happy.
2340 01:39:49,399 --> 01:39:51,789 For example, I looked at Llama from Facebook, which was at 45%
2341 01:39:51,789 --> 01:39:52,760 or something like this.
2342 01:39:52,760 --> 01:39:55,960 So that means that data doesn't come in fast enough
2343 01:39:55,960 --> 01:39:58,779 even for these big companies.
2344 01:39:58,779 --> 01:40:00,746 So one simple trick, and it might
2345 01:40:00,747 --> 01:40:02,579 be the only one I'm going to tell you about,
2346 01:40:02,579 --> 01:40:04,140 is low precision.
2347 01:40:04,140 --> 01:40:06,869 One simple idea is that, well, if I
2348 01:40:06,869 --> 01:40:09,251 put my floats in low precision,
2349 01:40:09,252 --> 01:40:10,710 then there are fewer bits
2350 01:40:10,710 --> 01:40:12,430 that I have to send to my GPUs.
2351 01:40:12,430 --> 01:40:14,710 If there are fewer bits, you get faster communication
2352 01:40:14,710 --> 01:40:16,029 and lower memory consumption.
2353 01:40:16,029 --> 01:40:17,699 Things are going to go faster.
2354 01:40:17,699 --> 01:40:19,529 And for deep learning, it just happens
2355 01:40:19,529 --> 01:40:22,800 that the trailing decimals are not that important.
2356 01:40:22,800 --> 01:40:25,739 When you do matrix multiplication, or when
2357 01:40:25,739 --> 01:40:28,380 you do, for example, SGD, there's already so much noise
2358 01:40:28,380 --> 01:40:33,840 that if you update something by 0.01 or 0.015, who cares?
2359 01:40:33,840 --> 01:40:37,949 So basically, instead of using 32 bits per float, which
2360 01:40:37,949 --> 01:40:41,460 is what people used to use, or 64 bits, which
2361 01:40:41,460 --> 01:40:43,659 is what you would use in other domains,
2362 01:40:43,659 --> 01:40:46,420 you use 16 bits for matrix multiplication.
2363 01:40:46,420 --> 01:40:49,550 So for every float, you use 16 bits.
2364 01:40:49,550 --> 01:40:51,270 And for training, you have this thing
2365 01:40:51,270 --> 01:40:54,160 that we call automatic mixed precision,
2366 01:40:54,159 --> 01:40:57,220 which means that some of the values are in 32 bits
2367 01:40:57,220 --> 01:40:58,720 and others are
2368 01:40:58,720 --> 01:41:00,122 in 16 bits.
2369 01:41:00,122 --> 01:41:02,079 Generally, the way you should be thinking about
2370 01:41:02,079 --> 01:41:05,029 it is that the weights of your model
2371 01:41:05,029 --> 01:41:06,969 are stored in 32 bits.
2372 01:41:06,970 --> 01:41:10,510 But just before the computation, you put everything in 16 bits.
2373 01:41:10,510 --> 01:41:12,400 That way, you do the computation super fast.
2374 01:41:12,399 --> 01:41:16,369 And at the end, you update your weights in 32 bits.
2375 01:41:16,369 --> 01:41:19,090 And the reason you do all the updates in 32 bits:
2376 01:41:19,090 --> 01:41:21,007 just think that if your learning rate, for example,
2377 01:41:21,006 --> 01:41:23,409 is very small, you still want to be able to make
2378 01:41:23,409 --> 01:41:25,090 a difference in your weights.
2379 01:41:25,090 --> 01:41:28,310 So all the computation is done in 16 bits,
2380 01:41:28,310 --> 01:41:30,830 but the weights are actually stored in 32 bits.
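[Editor's note: to make the MFU numbers concrete, an NVIDIA A100's 16-bit peak is roughly 312 TFLOP/s, so sustaining about 140 TFLOP/s would be roughly 45% MFU, the kind of figure quoted above. Below is a hedged, minimal sketch of automatic mixed precision in PyTorch on a toy model; exact APIs and dtypes vary with hardware and library version. The weights and the optimizer update stay in 32 bits, while the compute inside the autocast region runs in 16 bits, with gradient scaling so small 16-bit gradients don't underflow.]

    import torch
    import torch.nn as nn

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = nn.Linear(1024, 1024).to(device)   # toy stand-in; weights stored in fp32
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

    x = torch.randn(32, 1024, device=device)
    with torch.autocast(device_type=device, dtype=torch.float16,
                        enabled=(device == "cuda")):
        loss = model(x).pow(2).mean()    # forward pass computed in fp16
    scaler.scale(loss).backward()        # scale the loss so fp16 grads don't underflow
    scaler.step(opt)                     # unscale, then apply the update in fp32
    scaler.update()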
2381 01:41:30,829 --> 01:41:35,109 So that's the standard way that people are doing it.
2382 01:41:35,109 --> 01:41:36,849 OK, I'll actually talk about just one more thing,
2383 01:41:36,850 --> 01:41:39,010 operator fusion, and then I'll skip all the rest, because I think
2384 01:41:39,010 --> 01:41:40,270 this one is actually pretty cool.
2385 01:41:40,270 --> 01:41:42,730 As I just said, communication is very slow,
2386 01:41:42,729 --> 01:41:45,889 and actually, every line of PyTorch that you run
2387 01:41:45,890 --> 01:41:49,039 basically moves variables to and from the global memory of your GPU.
2388 01:41:49,039 --> 01:41:54,369 So say you have something like x1 = x.cos(),
2389 01:41:54,369 --> 01:41:56,309 and then you do x1.cos().
2390 01:41:56,310 --> 01:41:58,140 What is happening behind the scenes
2391 01:41:58,140 --> 01:42:00,070 is that you take x, which is data.
2392 01:42:00,069 --> 01:42:03,949 You ship it to the actual processors of your GPU.
2393 01:42:03,949 --> 01:42:05,130 You apply the cosine.
2394 01:42:05,130 --> 01:42:07,500 You ship it back to the main memory of your GPU,
2395 01:42:07,500 --> 01:42:09,340 and then you get to the next line.
2396 01:42:09,340 --> 01:42:12,510 You ship it back to the GPU processors,
2397 01:42:12,510 --> 01:42:15,600 you apply another cosine, and you ship it back again.
2398 01:42:15,600 --> 01:42:17,579 So another way to see it is that you
2399 01:42:17,579 --> 01:42:20,729 go from your DRAM, which is the global memory of your GPU,
2400 01:42:20,729 --> 01:42:22,419 and you ship the data to compute
2401 01:42:22,420 --> 01:42:24,109 and back, for every single line.
2402 01:42:24,109 --> 01:42:25,799 This is the naive way of doing it,
2403 01:42:25,800 --> 01:42:28,079 and it seems very wasteful.
2404 01:42:28,079 --> 01:42:31,769 So the simple idea of operator fusion is:
2405 01:42:31,770 --> 01:42:35,850 communicate once, do all the computation, and ship the result back once.
2406 01:42:35,850 --> 01:42:39,390 And this is exactly what fused kernels are (a small sketch appears below).
2407 01:42:39,390 --> 01:42:44,100 So if you ever want to make your computations
2408 01:42:44,100 --> 01:42:46,950 in PyTorch much faster, just apply
2409 01:42:46,949 --> 01:42:48,909 torch.compile to your model.
2410 01:42:48,909 --> 01:42:51,970 This is going to make your model around 2 times faster.
2411 01:42:51,970 --> 01:42:56,260 And what it does is simply rewrite your PyTorch code,
2412 01:42:56,260 --> 01:43:03,119 essentially into C++ and CUDA, so that it does
2413 01:43:03,119 --> 01:43:05,529 the communication only once, then does all the operations,
2414 01:43:05,529 --> 01:43:07,800 and then ships the result back.
2415 01:43:07,800 --> 01:43:10,390 OK, I'm not going to have time to talk about tiling.
2416 01:43:10,390 --> 01:43:11,670 Tiling is important.
2417 01:43:11,670 --> 01:43:12,600 Parallelization.
2418 01:43:12,600 --> 01:43:15,420 Parallelization is important.
2419 01:43:15,420 --> 01:43:17,149 And mixture of experts.
2420 01:43:17,149 --> 01:43:18,809 Mixture of experts is important.
2421 01:43:18,810 --> 01:43:19,780 Outlook.
2422 01:43:19,779 --> 01:43:23,099 There are many things we haven't talked about.
2423 01:43:23,100 --> 01:43:25,350 We haven't talked about architectures, and we definitely
2424 01:43:25,350 --> 01:43:27,480 haven't talked about inference.
2425 01:43:27,479 --> 01:43:29,859 There are many other things that are important with LLMs.
2426 01:43:29,859 --> 01:43:31,359 What is the UI that you use?
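[Editor's note: a hedged, minimal sketch of the fused-kernel idea via torch.compile; it requires PyTorch 2.x, and the "around 2 times faster" figure above is the lecturer's rough estimate for a full model, not a guarantee. Run eagerly, each cosine below is a separate kernel launch with a round trip through GPU global memory; the compiled version fuses the chain so the data makes one round trip.]

    import torch

    def repeated_cos(x):
        # Eight pointwise ops: eagerly, each one reads from and writes back
        # to global memory.
        for _ in range(8):
            x = torch.cos(x)
        return x

    fused = torch.compile(repeated_cos)  # fuses the chain into far fewer kernels

    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(4096, 4096, device=device)
    print(torch.allclose(repeated_cos(x), fused(x), atol=1e-5))  # same result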
2427 01:43:31,359 --> 01:43:34,289 On that UI question: arguably, the big novelty of ChatGPT was just
2428 01:43:34,289 --> 01:43:35,789 having a simple UI for using it.
2429 01:43:35,789 --> 01:43:36,930 Multi-modality.
2430 01:43:36,930 --> 01:43:38,820 All the misuses you could have.
2431 01:43:38,819 --> 01:43:41,319 The fact that there might not be enough data on the internet
2432 01:43:41,319 --> 01:43:42,420 to train all these models.
2433 01:43:42,420 --> 01:43:45,050 The legality of data collection-- so many other things.
2434 01:43:45,050 --> 01:43:47,699 If you are interested in all these topics,
2435 01:43:47,699 --> 01:43:49,479 I would suggest three classes.
2436 01:43:49,479 --> 01:43:54,809 CS224N is probably the one that touches the least on LLMs,
2437 01:43:54,810 --> 01:43:57,840 but it gives some background and historical context
2438 01:43:57,840 --> 01:44:01,510 for all the LLMs, plus some adjacent material.
2439 01:44:01,510 --> 01:44:04,920 CS324-- I think
2440 01:44:04,920 --> 01:44:07,619 it's just called Large Language Models-- has more
2441 01:44:07,619 --> 01:44:10,300 in-depth readings and lectures on everything I talked about.
2442 01:44:10,300 --> 01:44:13,930 And in CS336, which is large language models from scratch,
2443 01:44:13,930 --> 01:44:16,680 you actually build your own LLM.
2444 01:44:16,680 --> 01:44:20,530 It's an amazing class, also given by my two supervisors.
2445 01:44:20,529 --> 01:44:23,759 Very heavy workload, so be careful.
2446 01:44:23,760 --> 01:44:25,310 Great.