Understanding nested list comprehensions in Python

In my last blog post, I discussed list comprehensions, and how to think about them. Several people suggested (via e-mail, and in comments on the blog) that I should write a follow-up posting about nested list comprehensions.

I must admit that nested list comprehensions are something that I’ve shied away from for years. Every time I’ve tried to understand them, let alone teach them, I’ve found myself stumbling for words, without being clear about what was happening, what the syntax is, or where I would want to use them. I managed to use them on a few occasions, but only after a great deal of trial and error, and without really understanding what I was doing.

Fortunately, the requests that I received, asking how to work with such nested list comprehensions, forced me to get over my worries. I’ve figured out what’s going on, and I even think I understand why they confused me before.

The key thing to remember is that in a list comprehension, we’re dealing with an iterable. So when I say:

[ len(line) 
for line in open('/etc/passwd') ]

I’m saying that I want to iterate over the file object we got from opening /etc/passwd. There will be one element in the output list for each element in the input iterable — aka, every line in the file.

That’s great if I want my list comprehension to return something based on each line of /etc/passwd. But each line of /etc/passwd is a string, and thus also iterable. Maybe I want to return something not based on the lines of the file, but on the characters of each line.

Were I to use a “for” loop to process the file, I would use a nested loop, i.e., one loop inside of the other, with the outer loop iterating over lines and the inner loop iterating over the characters of each line. It turns out that we can use a nested list comprehension, too. Here’s a simple example of a nested list comprehension:

[(x,y) for x in range(5) for y in range(5)]

If your reaction to this is, “What in the blazes does that mean?!?” then you’re not alone. Until just recently, that’s what I thought, too.

However: If we rewrite the above nested list comprehension using my preferred (i.e., multi-line) list-comprehension style, I think that things become a bit clearer:

 [(x,y)  
 for x in range(5)  
 for y in range(5)]

Let’s take this apart:

  • Our output expression is the tuple (x,y). That is, this list comprehension will produce a list of two-element tuples.
  • We first run over the source range(5), giving x the values 0 through 4.
  • For each value of x, we run through the source range(5) again, giving y the values 0 through 4.
  • The output gets one element for each run of the final (second) “for” line, so here it will contain 5 × 5 = 25 tuples.
  • The output, not surprisingly, will be all of the two-element tuples from (0,0) to (4,4).
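If you find loops easier to read, it may help to unroll the comprehension into the nested “for” loop that it mirrors. This is only a sketch for comparison; the names pairs_loop and pairs_comp are my own, chosen for illustration.

```python
# Nested "for" loops that build the same list as the comprehension.
pairs_loop = []
for x in range(5):                  # first "for" line: the outer loop
    for y in range(5):              # second "for" line: the inner loop
        pairs_loop.append((x, y))   # the output expression

# The same thing, written as a nested list comprehension.
pairs_comp = [(x, y)
              for x in range(5)
              for y in range(5)]

print(pairs_loop == pairs_comp)     # True
print(len(pairs_comp))              # 25
```

Notice that the “for” lines appear in the same top-to-bottom order in both versions; only the output expression moves, from the innermost line of the loop to the first line of the comprehension.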

Now, let’s mix things up by changing them a bit:

 [(x,y)  
  for x in range(5)  
  for y in range(x+1)]

Notice that now, the maximum value of y will vary according to the value of x. So we’ll get from (0,0) to (4,4), but we won’t see such things as (2,4) because y will never be larger than x.

Again, it’s important to understand several things here:

  • Our “for y” loop will execute once for each iteration over x.
  • In our “for y” loop, we have access to the variable x.
  • In our “for x” loop, we don’t have access to y. (In Python 2, the last value of y leaked out of a list comprehension, but you really shouldn’t rely on that; Python 3 keeps comprehension variables private.)
  • Our (x,y) tuple is output once for each iteration of the *final* loop, at the bottom.
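A quick way to convince yourself of these points is to build the list and ask a few questions of it. The following is a small sketch; the variable name triangle is just my label for the result.

```python
triangle = [(x, y)
            for x in range(5)
            for y in range(x + 1)]   # the inner source depends on x

print(len(triangle))                 # 15, i.e. 1 + 2 + 3 + 4 + 5
print((2, 4) in triangle)            # False: y is never larger than x
print(all(y <= x for x, y in triangle))   # True
```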

Here’s another example: Assume that we have a few friends over, and that we have decided to play several games of Scrabble. Being Python programmers, we have stored our scores in a dictionary:

scores = {'Reuven':[300, 250, 350, 400], 
          'Atara':[200, 300, 450, 150], 
          'Shikma':[250, 380, 420, 120], 
          'Amotz':[100, 120, 150, 180] }

I want to know each player’s average score, so I write a little function:

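def average(scores):  
    # "//" is integer division, which gives the whole-number
    # averages shown below in both Python 2 and Python 3
    return sum(scores) // len(scores)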

If we want to find out each individual’s average score, we can use our function and a standard comprehension — in this case, a dict comprehension, to preserve the names:

 >>> { name : average(score)  
       for name, score in scores.items() }

{'Amotz': 137, 'Atara': 275, 'Reuven': 325, 'Shikma': 292}

But what if I want to get the average score, across all of the players? In such a case, I will need to grab each of the scores from inside of the inner lists. To do that, I can use a nested list comprehension:

>>> average([ one_score  
              for one_player_scores in scores.values()  
              for one_score in one_player_scores ])

257
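To see that the nested comprehension really does flatten the inner lists, here is a self-contained sketch that repeats the dictionary from above and compares the result against itertools.chain.from_iterable, a standard-library function that also flattens one level. The comparison is my own addition, not something the example above depends on.

```python
from itertools import chain

scores = {'Reuven': [300, 250, 350, 400],
          'Atara': [200, 300, 450, 150],
          'Shikma': [250, 380, 420, 120],
          'Amotz': [100, 120, 150, 180]}

# One output element per run of the final "for" line.
flat = [one_score
        for one_player_scores in scores.values()
        for one_score in one_player_scores]

# An alternative way to flatten one level of nesting.
flat_chain = list(chain.from_iterable(scores.values()))

print(flat == flat_chain)          # True
print(len(flat))                   # 16
print(sum(flat) // len(flat))      # 257
```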

What if I’m only interested (for whatever reason) in including scores that were above 200? As with all list comprehensions, I can use the “if” clause to weed out values that I don’t want. That condition can use any and all of the values that I have picked out of the various “for” lines:

>>> [ one_score      
      for one_player_scores in scores.values()     
      for one_score in one_player_scores
      if one_score > 200]

[300, 250, 350, 400, 300, 450, 250, 380, 420]
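Because the condition can see the variables from every “for” line, it can combine them. As a hypothetical variation, suppose I want the above-200 scores of everyone except Reuven. The sketch below assumes Python 3.7 or later, where dictionaries keep their insertion order; the name others_high is mine.

```python
scores = {'Reuven': [300, 250, 350, 400],
          'Atara': [200, 300, 450, 150],
          'Shikma': [250, 380, 420, 120],
          'Amotz': [100, 120, 150, 180]}

others_high = [one_score
               for name, one_player_scores in scores.items()
               for one_score in one_player_scores
               # "name" comes from the first "for" line,
               # "one_score" from the second
               if name != 'Reuven' and one_score > 200]

print(others_high)    # [300, 450, 250, 380, 420]
```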

If I want to put these above-200 scores into a CSV file of some sort, I could do the following:

>>> ','.join([ str(one_score)  
               for one_player_scores in scores.values() 
               for one_score in one_player_scores  
               if one_score > 200])

'300,250,350,400,300,450,250,380,420'

Here’s one final example that I hope will drive these points home: Let’s assume that I have information about a hotel. The hotel has stored its information in a Python list. The list contains lists (representing rooms), and each sublist contains one or more dictionaries (representing people). Here’s our data structure:

rooms = [[{'age': 14, 'hobby': 'horses', 'name': 'A'},  
          {'age': 12, 'hobby': 'piano', 'name': 'B'},  
          {'age': 9, 'hobby': 'chess', 'name': 'C'}],  
         [{'age': 15, 'hobby': 'programming', 'name': 'D'}, 
          {'age': 17, 'hobby': 'driving', 'name': 'E'}],  
         [{'age': 45, 'hobby': 'writing', 'name': 'F'},  
          {'age': 43, 'hobby': 'chess', 'name': 'G'}]]

What are the names of the people staying at our hotel?

 >>> [ person['name']      
       for room in rooms
       for person in room ]

['A', 'B', 'C', 'D', 'E', 'F', 'G']

How about the names of people staying in our hotel who enjoy chess?

>>> [ person['name']  
      for room in rooms  
      for person in room  
      if person['hobby'] == 'chess' ]

['C', 'G']

Basically, every “for” line flattens the items over which you’re iterating by one more level, and gives you access to that level in both the output expression (i.e., first line) and the condition (i.e., optional final line).
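For comparison, here is a sketch of the chess query written as ordinary nested loops, using a shortened copy of the hotel data. Each “for” line of the comprehension becomes one level of loop, and the “if” line becomes an “if” statement inside the innermost loop. The names chess_loop and chess_comp are mine.

```python
rooms = [[{'age': 14, 'hobby': 'horses', 'name': 'A'},
          {'age': 12, 'hobby': 'piano', 'name': 'B'},
          {'age': 9, 'hobby': 'chess', 'name': 'C'}],
         [{'age': 15, 'hobby': 'programming', 'name': 'D'},
          {'age': 17, 'hobby': 'driving', 'name': 'E'}],
         [{'age': 45, 'hobby': 'writing', 'name': 'F'},
          {'age': 43, 'hobby': 'chess', 'name': 'G'}]]

chess_loop = []
for room in rooms:                          # first "for" line
    for person in room:                     # second "for" line
        if person['hobby'] == 'chess':      # the condition
            chess_loop.append(person['name'])   # the output expression

chess_comp = [person['name']
              for room in rooms
              for person in room
              if person['hobby'] == 'chess']

print(chess_loop)                  # ['C', 'G']
print(chess_loop == chess_comp)    # True
```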

I hope that this helps you to understand nested list comprehensions. If it did, please let me know! (And if it didn’t, please let me know that, as well!)

Want to understand Python’s comprehensions? Think in Excel or SQL.

Comprehensions are among the most useful constructs in Python. They merge the old, trusty “map” and “filter” functions into a single piece of compact, elegant syntax, allowing us to express complex ideas in a minimum of code. Comprehensions are one of the most important tools in a Pythonista’s toolbox.

And yet, I have found that a very large number of Python programmers, including some experienced developers, are not completely comfortable with comprehensions. There are two reasons for this: First, it’s not obvious when to use them, and what sorts of problems they solve. The second problem, which is at least as important, is that the syntax is hard for people to remember and understand.

I’ve started to use a new explanation and introduction to comprehensions in my Python classes, and have found that it helps to lower the learning curve to some degree. In this post, I’m publicizing this explanation, in the hopes that it’ll help Python developers to understand when, where, and how to use comprehensions.

Let’s take a simple problem: I want to take a list of five integers, and get a list of their squares. If you give this problem to a new (or even intermediate) Python programmer, the answer would look something like this:

numbers = range(5)
output = [ ]
for number in numbers:
    output.append(number * number)
print(output)

Now, the thing is that this does work. (In my courses, I often use the phrase, “Unfortunately, this works.”) Often, when I talk about comprehensions, I talk about functional programming, the idea of immutable data structures, the idea that we don’t want to change things, and the benefits of thinking in terms of mapreduce.

But let’s ignore all of that, and ask a simpler question: If you were to give this problem to your accountant, how would they solve the problem?

Almost certainly, an accountant would fire up Excel, and put the numbers in a column:

A
-
0
1
2
3
4

Let’s assume that the above numbers are in the spreadsheet’s column A. The Excel user would, given this task, then tell Excel that column B should be calculated as A*A. And it would be done:

A  B
-  -
0  0
1  1
2  4
3  9
4  16

You could argue that the difference here is that Excel has a GUI, and Python doesn’t. But that’s missing the point. The real difference is that our accountant told Excel how to transform the first column into the second column, whereas our Python developer wrote a program that describes how to carry out that transformation.

We can think about this in a different way, too: Rather than solving the problem serially, as in the above for loop, the accountant is thinking in a parallel manner, applying a single expression to a large data set. The Excel user doesn’t care, or even know, the order in which the numbers are handed to the expression. The important thing is that the expression is applied once to each of the numbers, and that the final result appears in the correct order.

We might laugh at Excel, and dismiss its users as technical neophytes. And certainly, many users of Excel would deny that they possess serious programming chops. But this sort of thinking, which is so fundamental and natural to Excel users, is alien to many programmers. Which is a shame, because it allows us to express a very large number of ideas in a simple way.

To summarize this approach:

  • Think of your input as an iterable source of data
  • Think of what operation you want to apply to each element of that source
  • Get a new sequence out

That’s what the traditional “map” function does. Python does have a “map” function, but today, we typically use list comprehensions instead.
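To make the connection to “map” concrete, here is a small sketch that squares the same five numbers both ways. In Python 3, map returns an iterator rather than a list, so I wrap it in list() to compare the results. The variable names are mine.

```python
# The traditional "map" approach: apply a function to each element.
squares_map = list(map(lambda number: number * number, range(5)))

# The same transformation as a list comprehension.
squares_comp = [number * number for number in range(5)]

print(squares_map)                    # [0, 1, 4, 9, 16]
print(squares_map == squares_comp)    # True
```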

Let’s try to make this a bit more concrete, using the example that I used above: Let’s say that we have a list of five numbers, and we want to turn that list into a list of its squares. The list-comprehension syntax looks as follows:

[number * number for number in range(5) ]

Yikes. No wonder people are scared off by this syntax.  Let’s take the above syntax apart:

  • First of all, we’re going to get a list back. (It’s called a “list comprehension” for a reason.) That’s because of the square brackets, which are mandatory, and which tell Python what sort of object to create.
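  • The data source will be “range(5),” which gives us the integers 0 through 4. (In Python 2, range returns a list; in Python 3, it returns a range object, which is just as iterable.)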
  • Each element in the data source will be assigned, in turn, to the iteration variable “number.”
  • We’ll invoke the operation “number * number” on each element of the data source.

In other words, we’re creating a new list, the elements of which are the result of applying our expression to each element of the source. This sounds suspiciously like what our accountant did above, using Excel: We’re telling Python what we want, and how to transform our source to that result. But how are things done internally? How is the list created? We neither know nor care.

List-comprehension syntax can be daunting for people to understand, in part because the order of the operations seems unusual. I’ve found that it can help to rewrite list comprehensions in the following way:

[number * number
 for number in range(5) ]

Yes, that’s right — I now spread list comprehensions across two lines; the first describes the operation I want to invoke, and the second line describes the data source. If this still seems unfamiliar, let’s try to bring it into a context with which you might have some experience:

[number * number           # SELECT
 for number in range(5) ]  # FROM

While they’re not directly equivalent, there are a fair number of similarities between a SELECT query in SQL, the placement of its SELECT expression and FROM clause, and our list comprehension.  The FROM clause in an SQL query describes our data source, which is typically going to be a table, but can also be a view or even the result of a function call. And the initial part of the SELECT is often the name of a column, but  can include function calls and operators.

On the one hand, the SELECT-FROM combination seems almost too simple to mention, in that you’re just retrieving a selected set of values from a data source.  On the other hand, such queries form the backbone of the database industry. In the same way, such functionality forms the backbone of many Python programs, iterating over a data structure, and plucking out part of it, transforming that part, and then returning a new list.
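If you would like to see the two side by side, the following sketch uses Python’s built-in sqlite3 module with an in-memory database. The table name “numbers” and its single column “number” are made up for this illustration; nothing above depends on them.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE numbers (number INTEGER)")
conn.executemany("INSERT INTO numbers (number) VALUES (?)",
                 [(n,) for n in range(5)])

# SELECT expression FROM source
rows = conn.execute(
    "SELECT number * number FROM numbers ORDER BY number").fetchall()
sql_squares = [row[0] for row in rows]   # each row is a one-element tuple
conn.close()

# expression "for" variable "in" source
py_squares = [number * number
              for number in range(5)]

print(sql_squares)                 # [0, 1, 4, 9, 16]
print(sql_squares == py_squares)   # True
```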

One of my favorite examples (and an exercise in my ebook, “Practice Makes Python”) is to take the /etc/passwd file used in Unix, and get the usernames contained within that file. /etc/passwd consists of one record per line, and the fields are separated by colons. Here are several lines from the /etc/passwd on my computer:

nobody:*:-2:-2::0:0:Unprivileged User:/var/empty:/usr/bin/false
root:*:0:0::0:0:System Administrator:/var/root:/bin/sh
daemon:*:1:1::0:0:System Services:/var/root:/usr/bin/false
_uucp:*:4:4::0:0:Unix to Unix Copy Protocol:/var/spool/uucp:/usr/sbin/uucico

We might normally think of a file as a collection of bytes, to which we give semantic meaning when we read it. But in Python, we’re encouraged to see a file as an ordered, iterable collection of lines of text. True, I can read from a file based on bytes, but it’s so common to want to read files by line that the language provides several constructs to do so.

We know that we can iterate over the lines of a file:

for line in open('/etc/passwd'):
    print(line)

This demonstrates that a file is iterable, which means that it can serve as a data source for a list comprehension. This means that the above code can be rewritten as:

[line
 for line in open('/etc/passwd')]

Again, the first line in our list comprehension represents the expression we want to apply to every element of our data source. In this case, the expression is just the line.  If we want to get the username from each of  these lines, we just need to apply the “split” method on the string, returning a list — and then retrieve index 0 from the resulting list.  For example:

[line.split(":")[0]
 for line in open('/etc/passwd')]

Again, we can think of it in terms of an SQL query:

SELECT username
FROM users

But of course, “username” in the above is a column name.  A closer equivalent to my list comprehension would be a “users” table with an “info” column, queried as follows:

SELECT split_part(info, ':', 1)
FROM users;

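Note that in this case, I’m using PostgreSQL’s built-in “split_part” function to perform the equivalent operation to the str.split method in Python. (split_part numbers its fields starting at 1, which is why the query asks for field 1 where the Python code asks for index 0.)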

Remember that in the case of my SQL query, the result of a query always looks and acts like a table. The number and types of columns returned will depend on the number and types of expressions that I have in the SELECT  statement.  But the result set will have one or more columns, and zero or more rows.

In the same way, the result of a list comprehension is always going to be a list.  You can have whatever expression you want inside of the list comprehension; the expression represents one item in a list, not the list itself.

For example, let’s assume that I want to turn the usernames in /etc/passwd into a list of dictionaries. This doesn’t require a dictionary comprehension, which creates a single dictionary.  Rather, it requires a list  comprehension, in which the expression creates a dictionary.  Here’s a simple-minded such list comprehension:

[ {'name':line.split(":")[0]}
   for line in open('/etc/passwd')]

The above will work, in that it creates a list of dictionaries. And each dictionary has a single key-value pair.  But it seems a bit silly to do the above.  Rather, I’d probably want to have a dictionary containing the username and the numeric user ID, which is at index 2. I can then write:

[ {'name':line.split(":")[0], 'id':line.split(":")[2]}
for line in open('/etc/passwd')]

Again, we can think about this in terms of Excel, or even in terms of SQL: My query now produces a single column of results, in which each entry is a dictionary holding two strings. Or we can even say that the query produces two columns of results (name and id), which is not at all unusual in the world of SQL.

Let’s ignore the efficiency (or lack thereof) of invoking str.split twice in one comprehension: When I run this code on my Mac, it results in an exception, claiming that an index is out of range.

The reason is simple: I split each line into a list. But if there’s a line that doesn’t contain any : characters, it’ll be turned into a single-element list. I thus need to weed out any lines that won’t conform. Specifically, on my Mac at least, I need to remove any lines in /etc/passwd that are comments, meaning that they start with the ‘#’ character.
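You can reproduce the problem without touching a real /etc/passwd. In the sketch below, comment_line is a made-up example of a comment line, and user_line is copied from the sample records shown earlier.

```python
comment_line = "# User Database\n"     # hypothetical comment line
user_line = "root:*:0:0::0:0:System Administrator:/var/root:/bin/sh\n"

print(user_line.split(":")[2])         # 0
print(comment_line.split(":"))         # ['# User Database\n']

try:
    # A one-element list has no index 2
    user_id = comment_line.split(":")[2]
except IndexError as error:
    user_id = None
    print("IndexError:", error)
```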

In the world of list comprehensions, I say the following:

[ {'name':line.split(":")[0], 'id':line.split(":")[2]}
for line in open('/etc/passwd')
if not line.startswith("#")]

Let’s extend our earlier SQL analogy further, adding the equivalent SQL syntax in comments after our Python code:

[ {'name':line.split(":")[0], 'id':line.split(":")[2]}    # SELECT
for line in open('/etc/passwd')                           # FROM
if not line.startswith("#")]                              # WHERE

Of course, when the first line of our comprehension becomes this long, it’s often a good idea to move the work into a function. And since the first line can be any legitimate Python expression, a call to that function fits there perfectly well:

def get_user_info(line):
    name, passwd, id, rest = line.split(":", 3)   # max 4 fields
    return {'name':name, 'id':id}

[ get_user_info(line)             # SELECT
for line in open('/etc/passwd')   # FROM
if not line.startswith("#")]      # WHERE

A list comprehension thus gives you power similar to an SQL SELECT query — except that you’re not querying data in a table, but rather any object that conforms to Python’s iteration protocol, which includes a very  large number of built-in and custom-made objects.
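To show that the data source really can be any iterable, here is a sketch that swaps the file for an io.StringIO object holding a few sample lines. The comment line is invented, the two records come from the sample shown earlier, and I’ve renamed the “id” variable to “uid” so that it doesn’t shadow Python’s built-in id function.

```python
import io

def get_user_info(line):
    # maxsplit=3 yields at most 4 fields
    name, passwd, uid, rest = line.split(":", 3)
    return {'name': name, 'id': uid}

# A stand-in for open('/etc/passwd'); iterating over it yields lines.
fake_passwd = io.StringIO(
    "# A comment line\n"
    "nobody:*:-2:-2::0:0:Unprivileged User:/var/empty:/usr/bin/false\n"
    "root:*:0:0::0:0:System Administrator:/var/root:/bin/sh\n")

users = [get_user_info(line)             # SELECT
         for line in fake_passwd         # FROM
         if not line.startswith("#")]    # WHERE

print(users)
# [{'name': 'nobody', 'id': '-2'}, {'name': 'root', 'id': '0'}]
```

One caveat with this version of get_user_info: a non-comment line with fewer than four fields (a blank line, say) would fail to unpack and raise a ValueError, so a real program might need a stricter “if” line.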

Now, when would you want to use a list comprehension? And how does it differ from a for loop?

Using a list comprehension is appropriate whenever you want to transform data. That is, you have an iterable data source, and you want to create a new list whose elements are based on those of the data source. For  example, let’s assume that (for some reason) I want to find out how many times each character is used in /etc/passwd.  I can thus do the following, using collections.Counter:

from collections import Counter
counts = [Counter(line)
          for line in open('/etc/passwd')
          if not line.startswith("#")]

We know that “counts” is a list, because I used a list comprehension to create it. It is a list containing many Counter objects, one for each non-comment line in /etc/passwd. What if I want to find out what the most  popular character is in each line? I can modify my expression, asking the Counter object for the most common character:

counts = [Counter(line).most_common(1)
          for line in open('/etc/passwd')
          if not line.startswith("#")]

I can extend my expression even more, to get the most popular character from each line (inside of a two-element tuple in a one-element list):

counts = [Counter(line).most_common(1)[0][0]
          for line in open('/etc/passwd')
          if not line.startswith("#")]
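If the two [0] indexes look mysterious, it may help to peel the layers off one at a time on a single string. The line below is a shortened, made-up record, not one from my /etc/passwd.

```python
from collections import Counter

line = "name:*:0:0"                  # hypothetical shortened record
counter = Counter(line)

top = counter.most_common(1)
print(top)                           # [(':', 3)]  a one-element list
print(top[0])                        # (':', 3)    a (character, count) tuple
print(top[0][0])                     # :           the character itself
```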

And now I can find out how many times each most-popular character appears:

Counter([Counter(line).most_common(1)[0][0]
          for line in open('/etc/passwd')
          if not line.startswith("#")])

On my computer, the answer is:

Counter({':': 71, 'e': 4, 's': 1})

Meaning that in 71 non-comment lines, “:” is the most common, but in 4 lines it’s “e”, and in one line it’s “s”.  Now, could I have done this with a for loop?  Yes, of course — but because I’m dealing with iterables, and  because I’m using objects that work with such iterables, I can chain them together to get an answer in a way that doesn’t require me to tell Python how to do its job. I’m doing things like our accountant did, back at the  start of this article — I’m saying what I want, and letting Python do the hard work of dealing with this for me.

When would I use a for loop, then? The distinction is between whether you want to get a list back, and whether you want to execute a command a number of times.  If you want to build a list, and if it’s built on an iterable that already exists, then I’d say a list comprehension is almost certainly going to be the best bet.  But if you want to execute something a number of times without creating a list, then a comprehension is a bad way to do it; you should use a “for” loop, instead.

It’s true that a list comprehension usually runs somewhat faster than the equivalent “for” loop that appends to a list. But most of the time, “for” loops are used for different things than list comprehensions. “for” loops shouldn’t be used when you want to turn one iterable structure into another; that’s for comprehensions. And you shouldn’t execute something (e.g., print) many times via a list comprehension, even if you can do so via a called function.  I’ve found that the dividing line between when to use a “for” loop, and when to use a comprehension, is clearly delineated in the minds of experienced Python developers, but very hazy among newcomers to the language, and to these ideas.
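Here is a small sketch of that dividing line, with names invented for the example. The comprehension is the right tool for building the greetings; the loop is the right tool for printing them. The last line shows what happens if you print via a comprehension anyway: it works, but it builds a list of None values that nobody wants.

```python
names = ['A', 'B', 'C']

# Transforming one iterable into another: use a comprehension.
greetings = ['Hello, ' + name
             for name in names]

# Executing something several times: use a "for" loop.
for greeting in greetings:
    print(greeting)

# This also prints, but it creates a throwaway list as a side effect.
wasted = [print(greeting) for greeting in greetings]
print(wasted)        # [None, None, None]
```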

So, to summarize:

  • If you want to execute a command numerous times, use a “for” loop.
  • If you have an iterable, and want to create a new iterable, then a list comprehension is probably your best bet.
  • Building a list comprehension is sort of like working in Excel: You start with a set of data, and you create a new set of data. Any expression can be used to map from one to the other.  You don’t care about how Python does things behind the scenes; you just want to get your new data back.
  • A list comprehension consists of either two or three parts, which are often easier to understand if you put them on separate lines: (1) the expression, (2) the data source, and (3) an optional “if” clause.
  • These three lines are analogous to SQL’s SELECT, FROM, and WHERE clauses in a query.  And just as each of those (SELECT, FROM, and WHERE) can use arbitrary expressions, so too can Python’s list comprehensions use arbitrary expressions. A list comprehension will always return a list, though — just as a SELECT will always return a table-like result set.
  • Do you want to create a set, or perhaps a dictionary, rather than a list?  Then you can use a set comprehension or a dict comprehension. The idea is the same as everything I’ve said about list comprehensions, except that your result will be a single set or a single dictionary.
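As a brief sketch of that last point, here are a set comprehension and a dict comprehension over the same made-up list of hobbies. The only visible differences from a list comprehension are the curly braces and, for the dict, the key : value pair in the expression.

```python
hobbies = ['chess', 'piano', 'chess', 'horses']

# Set comprehension: duplicates collapse into a single set.
unique_lengths = {len(hobby)
                  for hobby in hobbies}

# Dict comprehension: the expression is a key : value pair.
length_by_hobby = {hobby: len(hobby)
                   for hobby in hobbies}

print(unique_lengths == {5, 6})      # True
print(length_by_hobby)               # {'chess': 5, 'piano': 5, 'horses': 6}
```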

Do you find it difficult to work with list comprehensions?  If so, what’s hard for you about them?  And does the above help to make their use, and their syntax easier to remember?  I’m eager to hear your reactions, so that I can improve these explanations even further.