10/12-10/23

The last two weeks in the OpenAI Scholars program have been really great. I've met a lot of really cool people and have been learning a ton.

I decided that, to start out, I wanted experience using PyTorch to implement various things roughly from scratch. The 2019 fast.ai course (especially part two) was really helpful for showing me how to do this, but doing it myself has been really useful for drilling in how things actually work.

Along the way, I've learned a few interesting things.

The sample standard deviation is biased for small sample sizes.

The intuitive explanation: the standard deviation is calculated by subtracting the mean from each sample, squaring the differences, averaging the squares, and taking the square root. But the mean you subtract is the mean of your samples, not the true mean, and the sample mean is by construction the point that minimizes the sum of squared distances to your samples. So the sum of squared differences around it is never larger than the sum around the true mean, and your standard deviation will tend to be an underestimate.
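The intuition can be made exact with the standard textbook argument, writing μ for the true mean, σ² for the true variance, and x̄ for the sample mean of n independent samples (so the variance of x̄ is σ²/n):

```latex
\begin{aligned}
\sum_{i=1}^{n} (x_i - \mu)^2 &= \sum_{i=1}^{n} (x_i - \bar{x})^2 + n(\bar{x} - \mu)^2 \\
\mathbb{E}\Big[\sum_{i=1}^{n} (x_i - \bar{x})^2\Big] &= n\sigma^2 - n\,\mathbb{E}\big[(\bar{x} - \mu)^2\big] = n\sigma^2 - n\cdot\frac{\sigma^2}{n} = (n-1)\,\sigma^2 \\
\mathbb{E}\Big[\frac{1}{n}\sum_{i=1}^{n} (x_i - \bar{x})^2\Big] &= \frac{n-1}{n}\,\sigma^2 < \sigma^2
\end{aligned}
```

The first line is the "sample mean is closer to your samples" point: the extra term n(x̄ - μ)² is never negative. The last line shows the variance estimate is low by a factor of (n-1)/n, which is largest for tiny batches.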

One interesting point here is that if you compute the standard deviation over the samples in a batch, do this lots of times, and average all of those standard deviations, it doesn't help. Because each estimate is biased, averaging many of them just gives you a very precise estimate of the wrong value. Thus, the only solutions are to use a corrective term that accounts for the bias, or to use a bigger batch size.

How serious an issue is this? The values on this page will give you a good idea. The exact numbers depend on the distribution, but for a normal distribution with standard deviation 1.0, you can read the "sample size" column as batch size and the "numerical value" column as the standard deviation you will measure on average (these values are for the estimator that divides by n-1). With batch size 3 or 4, the bias means you measure a standard deviation of ~0.9, and with batch size 2 it's 0.79! The issue is not nearly as bad once batch size gets larger. I did some simple tests and found that those correction terms work fairly well for small batch sizes in things like batch norm and Fixup initialization (though, again, a better solution is just to use bigger batch sizes when that's an option).
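To check numbers like these without trusting a table, here is a small Monte Carlo sketch in plain Python (standard library only; the helper names are made up for illustration). It draws many batches from a standard normal, computes each batch's standard deviation, and averages them. For batch size 2 the averages settle near 0.56 when dividing by n and near 0.80 when dividing by n-1, not 1.0, no matter how many batches you average.

```python
import math
import random


def sample_std(xs, ddof):
    """Standard deviation of xs, dividing the sum of squares by len(xs) - ddof."""
    n = len(xs)
    mean = sum(xs) / n
    sum_sq = sum((x - mean) ** 2 for x in xs)
    return math.sqrt(sum_sq / (n - ddof))


def average_batch_std(batch_size, ddof, num_batches=20000, seed=0):
    """Average the per-batch standard deviation over many batches drawn from N(0, 1)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_batches):
        batch = [rng.gauss(0.0, 1.0) for _ in range(batch_size)]
        total += sample_std(batch, ddof)
    return total / num_batches


if __name__ == "__main__":
    # The true standard deviation is 1.0; small batches come out noticeably low
    # however many batches we average over.
    for n in (2, 3, 4, 10):
        divide_by_n = average_batch_std(n, ddof=0)
        divide_by_n_minus_1 = average_batch_std(n, ddof=1)
        print(f"batch size {n:>2}: /n -> {divide_by_n:.3f}, /(n-1) -> {divide_by_n_minus_1:.3f}")
```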

By the way, dividing by n-1 instead of n (which is actually the default behaviour of PyTorch's torch.std; it's what the unbiased parameter refers to) helps a little, but doesn't completely solve the issue. Dividing by n-1 makes the variance estimate unbiased, but the square root is concave, so the square root of an unbiased variance estimate still underestimates the standard deviation on average (Jensen's inequality). Corrective terms sort of solve it. I say "sort of" because the caveat from that article, "it depends on your distribution", really seems to matter. For example, ReLU outputs seemed to require slightly different corrective terms, and I imagine real data would have different distributional properties than normal distributions as well.
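For normal data, the usual corrective term is written c4(n) = sqrt(2/(n-1)) · Γ(n/2) / Γ((n-1)/2): it is the expected value of the n-1 standard deviation when the true value is 1, so dividing by it removes the bias. For n = 2 it equals sqrt(2/π) ≈ 0.798, matching the 0.79 figure above. The sketch below (plain Python, hypothetical helper names) assumes normally distributed data; as noted above, other distributions such as ReLU outputs need different constants.

```python
import math
import random


def c4(n):
    """Mean of the (n-1)-divisor sample std for n i.i.d. normal draws with sigma = 1."""
    if n < 2:
        raise ValueError("need at least two samples")
    # lgamma keeps the Gamma-function ratio from overflowing for large n
    log_ratio = math.lgamma(n / 2) - math.lgamma((n - 1) / 2)
    return math.sqrt(2 / (n - 1)) * math.exp(log_ratio)


def corrected_std(xs):
    """(n-1)-divisor sample std divided by c4(n); unbiased only for normal data."""
    n = len(xs)
    mean = sum(xs) / n
    s = math.sqrt(sum((x - mean) ** 2 for x in xs) / (n - 1))
    return s / c4(n)


def mean_corrected_std(batch_size, num_batches=20000, seed=0):
    """Average corrected_std over many N(0, 1) batches; should land close to 1.0."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_batches):
        total += corrected_std([rng.gauss(0.0, 1.0) for _ in range(batch_size)])
    return total / num_batches


print([round(c4(n), 4) for n in (2, 3, 4, 10)])
print(round(mean_corrected_std(2), 3))
```

For batch norm, the same idea would mean dividing the per-batch standard deviation by c4(batch_size) before normalizing, which is only exact if the activations are close to normal.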

Transformers and self attention are much simpler than I thought.

The idea is quite nice: we want to model a database. At a high conceptual level, a database stores values, and you use keys to store those values. When looking something up, you ask "how closely does my query match each of these keys?" and return the value stored under the key that matches your query most closely (I know this isn't actually how databases work, but it's a decent high-level model).

Attention just does a fuzzy version of this: we take a weighted sum of all values in the database, with each value multiplied by how closely our query matches that value's key. Specifically:

For each word in a sentence, make an embedding (that just means: have a giant weight matrix with one row for every word in your vocabulary, and convert a word into a vector by looking up its corresponding row). Combine this embedding with a positional embedding, which encodes where the word is in the sentence; in the original Transformer and in GPT the two vectors are added together. There are lots of creative ways people have done positional embeddings, with different tradeoffs.
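A minimal sketch of this step in plain Python, using lists instead of tensors. The embedding table is random here (it would be learned), and the positional embedding is the fixed sinusoidal scheme from the original Transformer paper, chosen only as one example of the many options; the function names are hypothetical.

```python
import math
import random


def make_embedding_table(vocab_size, dim, seed=0):
    """One row per vocabulary word; random here, learned in a real model."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 0.02) for _ in range(dim)] for _ in range(vocab_size)]


def sinusoidal_position(pos, dim):
    """Sinusoidal positional encoding from the original Transformer paper (one option of many)."""
    enc = []
    for k in range(dim):
        # dimensions 2i and 2i+1 share the frequency 1 / 10000^(2i / dim)
        freq = 1.0 / (10000 ** ((k // 2) * 2 / dim))
        enc.append(math.sin(pos * freq) if k % 2 == 0 else math.cos(pos * freq))
    return enc


def embed_sentence(token_ids, table):
    """x_i = (table row for word i) + (positional encoding for position i)."""
    dim = len(table[0])
    xs = []
    for pos, token in enumerate(token_ids):
        word_vec = table[token]  # the "look up its row" step
        pos_vec = sinusoidal_position(pos, dim)
        xs.append([w + p for w, p in zip(word_vec, pos_vec)])
    return xs


vocab = {"the": 0, "cat": 1, "sat": 2, "on": 3, "mat": 4}
table = make_embedding_table(len(vocab), dim=4)
xs = embed_sentence([vocab[w] for w in "the cat sat on the mat".split()], table)
```

Both occurrences of "the" share a table row, but their x_i differ because their positions differ.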

Take those embedded vectors and call them x_i. Now compute three vectors, where Q, K, and V are learned weight matrices shared across all positions:

query vector: q_i = Q*x_i

key vector: k_i = K*x_i

value vector: v_i = V*x_i
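In code, the three projections are just matrix-vector products applied to every position with the same matrices. This plain-Python sketch uses random stand-ins for the learned Q, K, and V and made-up sizes (embedding size 4, query/key/value size 3).

```python
import math
import random


def matvec(m, x):
    """Multiply a matrix (a list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in m]


def random_matrix(rows, cols, rng):
    """Stand-in for a learned weight matrix: small random values."""
    scale = 1.0 / math.sqrt(cols)
    return [[rng.gauss(0.0, scale) for _ in range(cols)] for _ in range(rows)]


def project_qkv(xs, Q, K, V):
    """Compute q_i = Q x_i, k_i = K x_i, v_i = V x_i for every position i."""
    qs = [matvec(Q, x) for x in xs]
    ks = [matvec(K, x) for x in xs]
    vs = [matvec(V, x) for x in xs]
    return qs, ks, vs


rng = random.Random(0)
d_model, d_head = 4, 3
Q = random_matrix(d_head, d_model, rng)
K = random_matrix(d_head, d_model, rng)
V = random_matrix(d_head, d_model, rng)
xs = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.5, 0.5, 0.0, 0.0]]
qs, ks, vs = project_qkv(xs, Q, K, V)
```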

For a given word at position i, we want to "look up" the other words in our database to see which ones might be useful for predicting the word at position i+1. To do this, we compute a score:

s_i,j = q_i dot k_j

This captures how closely our query for word i matches the key that word j is stored under.

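Now that we have those scores, for each i we apply softmax* over j, which turns the scores into values between 0 and 1 that sum to 1. We also need to make sure the weight is 0 for any j > i, since it's really easy to predict the next word if you can just look at what is there. Setting s_i,j to 0 isn't enough, since softmax would still give it a positive weight; instead, set s_i,j to negative infinity before the softmax, which makes its weight exactly 0. Call these adjusted scores p_i,j.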
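A plain-Python sketch of the scoring, masking, and softmax step (hypothetical function names). Future positions get a score of negative infinity, which the numerically stable softmax turns into a weight of exactly 0; the first print shows that a score of 0 would not have hidden anything.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def softmax(scores):
    """Numerically stable softmax; a score of -inf gets weight exactly 0."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def causal_attention_weights(qs, ks):
    """p[i][j] = softmax over j of q_i . k_j, with every j > i masked out."""
    n = len(qs)
    weights = []
    for i in range(n):
        scores = [dot(qs[i], ks[j]) if j <= i else -math.inf for j in range(n)]
        weights.append(softmax(scores))
    return weights


# A "future" score set to 0 instead of -inf still gets a positive weight:
print(softmax([2.0, 0.0]))
print(causal_attention_weights([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]]))
```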

Now we just do a weighted sum over all the values:

u_i = sum over j of (p_i,j*v_j)

If our value size is different from our original x_i size, we need another matrix to project u_i back to the right size. This gives us a new "embedded vector" for each word that has been processed by our first attention layer. We can feed each of these vectors into another attention layer, and another, and so on. At the end, we take the output and use it to predict a distribution over possible next words.
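Putting these steps together, here is a minimal single-head sketch in plain Python (lists instead of tensors so it runs anywhere). The weights are random stand-ins for learned parameters, and O is my own name for the matrix that projects u_i back to the embedding size. It deliberately leaves out residuals, layer norm, multiple heads, and the 1/sqrt(d) score scaling that real implementations use.

```python
import math
import random


def matvec(m, x):
    """Matrix (list of rows) times vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in m]


def softmax(scores):
    """Numerically stable softmax; -inf entries get probability 0."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def random_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 1.0 / math.sqrt(cols)) for _ in range(cols)] for _ in range(rows)]


class CausalSelfAttention:
    """One head of causal self-attention following the steps in the text."""

    def __init__(self, d_model, d_head, seed=0):
        rng = random.Random(seed)
        self.Q = random_matrix(d_head, d_model, rng)
        self.K = random_matrix(d_head, d_model, rng)
        self.V = random_matrix(d_head, d_model, rng)
        self.O = random_matrix(d_model, d_head, rng)  # hypothetical output projection

    def __call__(self, xs):
        n = len(xs)
        qs = [matvec(self.Q, x) for x in xs]
        ks = [matvec(self.K, x) for x in xs]
        vs = [matvec(self.V, x) for x in xs]
        outputs = []
        for i in range(n):
            # s_i,j = q_i . k_j for j <= i; future positions are masked with -inf
            scores = [sum(a * b for a, b in zip(qs[i], ks[j])) if j <= i else -math.inf
                      for j in range(n)]
            p = softmax(scores)
            # u_i = sum over j of p_i,j * v_j
            u = [sum(p[j] * vs[j][c] for j in range(n)) for c in range(len(vs[0]))]
            outputs.append(matvec(self.O, u))  # back to the embedding size
        return outputs


# Stack two layers: the outputs of one are the inputs of the next.
layers = [CausalSelfAttention(d_model=4, d_head=3, seed=s) for s in (1, 2)]
xs = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]]
for layer in layers:
    xs = layer(xs)
```

Because of the mask, changing a later word never changes the outputs at earlier positions, which is what lets every position be trained to predict its next word in parallel.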

There are some small tweaks (adding residual connections, dense layers in between, layer norm, scaling the scores by 1/sqrt(d) where d is the query/key size, doing multiple heads in parallel, etc.), but this is essentially how GPT-style models work. The original Transformer is slightly more complicated because it has an encoder and a decoder, and I imagine you could build even weirder graph-type structures out of these attention pieces if you were feeling creative.

One fun point about this is that while the time complexity is quadratic in the number of words (because you need to consider every pair), you can keep the space complexity linear in the number of words by doing the computation one word at a time instead of materializing the full matrix of scores all at once.
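Computing one query at a time already avoids the n-by-n score matrix. The sketch below goes a step further for a single query: it folds the softmax and the weighted sum into one pass over the keys by keeping a running maximum, a running normalizer, and a running weighted sum of values (often called an online softmax). The function names are hypothetical, and the naive version is included only to check that both agree. Time is still quadratic overall, since query i still visits every key j <= i.

```python
import math


def attend_one_query(q, keys, values, upto):
    """Causal attention output for one query over positions 0..upto.

    Keeps a running max score, a running softmax denominator, and a running
    weighted sum of values, so the row of scores is never stored.
    """
    running_max = -math.inf
    denom = 0.0
    acc = [0.0] * len(values[0])
    for j in range(upto + 1):
        s = sum(a * b for a, b in zip(q, keys[j]))
        new_max = max(running_max, s)
        scale = math.exp(running_max - new_max)  # rescale earlier terms to the new max
        weight = math.exp(s - new_max)
        denom = denom * scale + weight
        acc = [a * scale + weight * v for a, v in zip(acc, values[j])]
        running_max = new_max
    return [a / denom for a in acc]


def naive_causal_attention(qs, ks, vs):
    """Reference version that builds each full row of weights before summing."""
    outputs = []
    for i in range(len(qs)):
        scores = [sum(a * b for a, b in zip(qs[i], ks[j])) for j in range(i + 1)]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        z = sum(exps)
        outputs.append([sum(exps[j] / z * vs[j][c] for j in range(i + 1))
                        for c in range(len(vs[0]))])
    return outputs


qs = [[0.3 * i - 0.1 * c for c in range(3)] for i in range(5)]
ks = [[0.2 * c + 0.1 * i for c in range(3)] for i in range(5)]
vs = [[float(i + c) for c in range(2)] for i in range(5)]
streamed = [attend_one_query(qs[i], ks, vs, i) for i in range(len(qs))]
reference = naive_causal_attention(qs, ks, vs)
```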

*In the fast.ai course, Jeremy warned us to question any use of softmax. Softmax makes sense when you are classifying things and each example is always in exactly one class, never in multiple classes or none. If no classes or multiple classes are possible in your dataset, you should use sigmoid instead. In this case, softmax seems to be used to ensure that we are doing a weighted average (the weights sum to 1), which is probably important to keep magnitudes from varying too much, but I'm curious to look into this more.
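A tiny comparison (plain Python, illustrative helper names) of the difference described in the footnote: softmax makes classes compete for a fixed budget of probability, while sigmoid scores each class independently. With two equally strong logits of 2.0, softmax gives each about 0.5, while sigmoid gives each about 0.88. The last lines check the property that matters for attention: a softmax-weighted combination of values always lies between the smallest and largest value.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


logits = [2.0, 2.0, -3.0]
soft = softmax(logits)                # classes compete: outputs sum to 1
indep = [sigmoid(z) for z in logits]  # each class is scored on its own
print([round(p, 3) for p in soft])
print([round(p, 3) for p in indep])

# Softmax weights make the output a weighted average of the values, so it can
# never exceed the largest value or fall below the smallest one.
values = [10.0, -10.0, 4.0]
weighted = sum(p * v for p, v in zip(soft, values))
print(min(values) <= weighted <= max(values))
```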

Finally, I've been thinking about research directions. 

There are two general directions I've been thinking about right now.

The first is whether we can break down the kinds of things large language models learn into concrete tasks. The high-level idea is to understand what these models are learning at different stages of training, then turn those learned skills into synthetic tasks. This would serve a few goals:

1. Synthetic tasks ideally would be open-ended and could be scaled up to attain unbounded superhuman performance, whereas learning from human data sort of limits you to human-level performance, or sometimes slightly better. Theorem proving and grammar induction are two examples here. Relatedly, synthetic tasks can be studied and carefully optimized to provide fast training, whereas human data is often pretty redundant (though I still find data distillation techniques super promising).

2. By taking lots of checkpoints of the model as it trains on general text (or theorem proving), then studying how quickly those checkpoints transfer to the synthetic tasks, we get a way of saying when the model learns those skills. Similarly, seeing how well a model trained on synthetic tasks transfers to real-world tasks helps you understand how much that synthetic task is "embedded" in the real-world task. Approaching this from both angles helps you decompose your improvements in loss into an understanding of the essential subtasks that make up the task you want your model to do.

3. For synthetic tasks, it's much easier to measure where a model sits on the generalization vs. memorization continuum (when it becomes easier for a model to learn the general task instead of memorizing specific cases or simple heuristics). See, for example, this paper.

4. This can help give us a better understanding of human data, which might let us study theoretical questions like "what would superhuman performance at X look like?" in a relatively safe way. The long-term hope is that if we could build a fairly intelligent model trained purely on synthetic data, ensuring it is safe would likely be more plausible, because we can use the theoretical properties of our data when analyzing what kinds of behaviours the model might have. It also lets us decouple world knowledge from other capabilities, so the model would have a harder time doing harmful things, since initially it may have no world knowledge unless we specifically give it that.

However, I'm not certain the above approach is plausible. In image tasks, people discuss the "simulator reality gap" even for synthetic data that looks very much like real-world data. Getting synthetic data that actually transfers to real-world data is really hard, and it's even less studied in language, so this is probably very difficult. I imagine much of the work would be searching for and constructing synthetic tasks, finding places they fail, and iterating. Still, I think there are many interesting generalization-related questions (point 3) that could be studied here as a fallback if the idea as a whole isn't plausible.

The second question has to do with problems we currently see with ML systems in the real world. A few people have expressed the concern: "if we can't even make sure our current systems are doing what we want, why should we have any hope that later systems will do what we want?" It's an argument against scaling up that's hard to rebut, though I recognize there are some reasonable counterpoints.

There are lots of interesting questions in that area; I'm particularly interested in feedback loops that happen in simulated social-network-type settings and recommendation systems. I'm aware that large companies think hard about these problems, and I don't expect to get anywhere near the sophistication of the answers they probably have, but I still think open conversation is important. Some relevant reading:

The AI Economist

Popularity Signals in Trial-Offer Markets with Social Influence and Position Bias

Aligning Popularity and Quality in Online Cultural Markets

The general idea I've been thinking about is using RL agents to model human behaviour in very simple settings that roughly approximate reality. The point is that if you aren't careful, recommendation systems (even ones that are theoretically optimal at minimizing their loss!) can cause feedback loops that end up producing outcomes you didn't want. I think Pólya urns are a very understudied model for theoretically addressing this phenomenon, and I'm not aware of much work using them in these contexts (though in talking with others, I've found that people sometimes come up with and use the fixes that urn theory recommends without needing the theory of Pólya urns to find them). But my thoughts on this topic are still very bare-bones, and I have quite a bit of reading to do.
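As a toy illustration of the feedback loop (my own sketch, not taken from the papers above): treat each item as a color in a Pólya urn, where every time an item is picked in proportion to its popularity, it becomes more popular. For two identical items that start tied at one ball each, the long-run share of the first item is known to be uniformly distributed between 0 and 1, so across runs the "winner" is decided by early luck rather than quality. The function names are hypothetical.

```python
import random


def polya_urn(initial_counts, steps, seed=None):
    """Classic Polya urn: pick a color with probability proportional to its count,
    then put the ball back together with one extra ball of the same color."""
    rng = random.Random(seed)
    counts = list(initial_counts)
    colors = range(len(counts))
    for _ in range(steps):
        picked = rng.choices(colors, weights=counts)[0]
        counts[picked] += 1
    return counts


def first_item_shares(num_runs, steps):
    """Final popularity share of item 0 when two identical items start tied at 1 vs 1."""
    return [polya_urn([1, 1], steps, seed=s)[0] / (steps + 2) for s in range(num_runs)]


shares = first_item_shares(num_runs=200, steps=1000)
print(f"min share {min(shares):.2f}, max share {max(shares):.2f}")
```

In this toy model, some runs end with one item dominating and others with the reverse, even though nothing distinguishes the two items.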
