Five years ago I started a new role and I suddenly found myself, a staunch R fan, having to code in Python on a daily basis. Working with data, most of my Python work involved using pandas, the Python data frame library, and initially I found it quite hard and clunky to use, being used to the silky smooth API of R’s tidyverse. And you know what? It still feels hard and clunky, even now, 5 years later!
But what seems even harder is explaining to “Python people” what they are missing out on. From their perspective, pandas is this fantastic tool that makes data science in Python possible. And it is a fantastic tool, don’t get me wrong. But if you, like me, end up in many “pandas is great, but…”-type discussions and lack clear examples to link to, here’s a somewhat typical example of a simple analysis, built from the ground up, that flows nicely in R and the tidyverse but becomes clunky and complicated using Python and pandas.
Let’s first step through a short analysis of purchases using R and the tidyverse. After that we’ll see how the same solution using Python and pandas compares.
Analyzing purchases in R
We’ve been given a table of purchases with different amounts, where the customer could have received a discount, and where each purchase happened in a country. Finance now wants to know: How much do we typically sell in each country? Let’s read in the data and take a look:
library(tidyverse)
purchases <- read_csv("purchases.csv")
purchases |> head()
# A tibble: 6 × 3
country amount discount
<chr> <dbl> <dbl>
1 USA 2000 10
2 USA 3500 15
3 USA 3000 20
4 Canada 120 12
5 Canada 180 18
6 Canada 3100 21
Now, without bothering to print out the intermediate results, here’s how a quick pipeline could be built up, answering Finance’s question.
“How much do we sell..? Let’s take the total sum!”
purchases$amount |> sum()
“Ah, they wanted it by country…”
purchases |>
group_by(country) |>
summarize(total = sum(amount))
“And I guess I should deduct the discount.” (#👈/👆/👇 marks lines that changed/moved)
purchases |>
group_by(country) |>
summarize(total = sum(amount - discount)) #👈
“Oh, and Maria asked me to remove any outliers. Let’s remove everything 10x larger than the median.”
purchases |>
filter(amount <= median(amount) * 10) |> #👈
group_by(country) |>
summarize(total = sum(amount - discount))
“I probably should use the median within each country. Prices are quite different across the globe…”
purchases |>
group_by(country) |> #👆
filter(amount <= median(amount) * 10) |> #👇
summarize(total = sum(amount - discount))
# A tibble: 11 × 2
country total
<chr> <dbl>
1 Australia 540
2 Brazil 414
3 Canada 270
4 France 450
5 Germany 513
6 India 648
7 Italy 567
8 Japan 621
9 Spain 594
10 UK 432
11 USA 8455
“And we’re done, let’s go for second breakfast!”
Analyzing purchases in Python
We’re now going to take a look at how this little analysis would look in Python and pandas. One complication here is that pandas can be written in many different styles; it’s not like in the tidyverse, where there’s often one obvious way to do something. Here we’re opting for writing pandas using the fluent method chaining API, as opposed to the more “imperative” approach that results in a lot of repeats of df and statements like df[df["this"] == "that"] = calc_some(df["other_thing"]). We’re also opting for always returning a table with all the data in the data frame proper. We don’t want data hidden away in the index (that is, pandas’ really advanced system for row and column names). Having data in the index is generally annoying when one wants to process the data further or when turning the data into plots.
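For reference, here’s a runnable approximation of that imperative style; the column names and the doubling operation are invented for illustration (calc_some above is a made-up placeholder):

```python
import pandas as pd

# Toy frame; "this" / "other_thing" echo the made-up names above
df = pd.DataFrame({"this": ["that", "other"], "other_thing": [10, 20]})

# Imperative style: `df` is repeated on every line and mutated in place
mask = df["this"] == "that"
df.loc[mask, "other_thing"] = df.loc[mask, "other_thing"] * 2
```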
Again, let’s step through the R version of the analysis, and below let’s write the corresponding pandas code. As before, #👈/👆/👇 marks lines that have changed/moved.
Reading in the data
# R
library(tidyverse)
purchases <- read_csv("purchases.csv")
purchases |> head()
# A tibble: 6 × 3
country amount discount
<chr> <dbl> <dbl>
1 USA 2000 10
2 USA 3500 15
3 USA 3000 20
4 Canada 120 12
5 Canada 180 18
6 Canada 3100 21
This is basically the same in pandas. So far so good!
# Python
import pandas as pd
purchases = pd.read_csv("purchases.csv")
purchases.head()
country amount discount
0 USA 2000 10
1 USA 3500 15
2 USA 3000 20
3 Canada 120 12
4 Canada 180 18
“How much do we sell..? Let’s take the total sum!”
# R
purchases$amount |> sum()
[1] 17210
This is also similar in pandas:
# Python
purchases["amount"].sum()
17210
(However, note that this method, pandas.Series.sum(), is not the same as pandas.DataFrame.sum(), or numpy.sum(), or the built-in sum function, each of which has different arguments and behaviors. In R, it’s always the same built-in sum() function.)
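To make that concrete, here’s a small sketch (on made-up data, not the purchases.csv above) of one behavioral difference: the default handling of missing values.

```python
import math

import pandas as pd

s = pd.Series([1.0, float("nan"), 2.0])

# pandas.Series.sum() skips missing values by default (skipna=True)...
pandas_total = s.sum()

# ...while the built-in sum just adds the raw values, NaN included
builtin_total = sum(s)

print(pandas_total)   # 3.0
print(builtin_total)  # nan
```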
“Ah, they wanted it by country…”
# R
purchases |>
group_by(country) |>
summarize(total = sum(amount))
# A tibble: 11 × 2
country total
<chr> <dbl>
1 Australia 600
2 Brazil 460
3 Canada 3400
4 France 500
5 Germany 570
6 India 720
7 Italy 630
8 Japan 690
9 Spain 660
10 UK 480
11 USA 8500
This is also very similar in Python:
# Python
(purchases
.groupby("country")["amount"]
.sum()
)
country
Australia 600
Brazil 460
Canada 3400
France 500
Germany 570
India 720
Italy 630
Japan 690
Spain 660
UK 480
USA 8500
Name: amount, dtype: int64
Ah, but here we actually need to do more work. The output has now turned into a pandas.Series, not a data frame, and country got moved to the index. We can solve this by using .reset_index(). Also, we’re not happy with the amount column name, but .sum() does not allow us to specify a different name. Instead of .sum() we can use the .agg() method to get around this:
# Python
(purchases
.groupby("country")
.agg(total=("amount", "sum")) #👈
.reset_index() #👈
)
country total
0 Australia 600
1 Brazil 460
2 Canada 3400
3 France 500
4 Germany 570
5 India 720
6 Italy 630
7 Japan 690
8 Spain 660
9 UK 480
10 USA 8500
(Another thing that’s new here is that we now have to pass the sum method as a "sum" string.)
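If passing code as strings bothers you: the "sum" string is just pandas’ name for its built-in aggregation, and, as far as I know, a plain callable works in its place. A sketch on made-up data:

```python
import pandas as pd

purchases = pd.DataFrame({
    "country": ["USA", "USA", "Canada"],
    "amount": [2000, 3500, 120],
})

# Named aggregation with a string alias...
by_string = (purchases
    .groupby("country")
    .agg(total=("amount", "sum"))
    .reset_index()
)

# ...and the same thing with a callable instead
by_callable = (purchases
    .groupby("country")
    .agg(total=("amount", lambda s: s.sum()))
    .reset_index()
)
```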
“And I guess I should deduct the discount.”
A tiny change in R…
# R
purchases |>
group_by(country) |>
summarize(total = sum(amount - discount)) #👈
# A tibble: 11 × 2
country total
<chr> <dbl>
1 Australia 540
2 Brazil 414
3 Canada 3349
4 France 450
5 Germany 513
6 India 648
7 Italy 567
8 Japan 621
9 Spain 594
10 UK 432
11 USA 8455
… but a large change in Python. The .agg() method can only aggregate single columns; since we now need to combine two columns, we have to fall back on .apply(), which can handle any type of aggregation. As we want to avoid a column with the enigmatic name 0, we also have to use .rename() to get back to total again.
# Python
(purchases
.groupby("country")
.apply(lambda df: (df["amount"] - df["discount"]).sum()) #👈
.reset_index()
.rename(columns={0: "total"}) #👈
)
country total
0 Australia 540
1 Brazil 414
2 Canada 3349
3 France 450
4 Germany 513
5 India 648
6 Italy 567
7 Japan 621
8 Spain 594
9 UK 432
10 USA 8455
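For what it’s worth, the .apply() dance can also be sidestepped by precomputing the net amount with .assign() first, at the cost of a throwaway column. A sketch on made-up data (not the full purchases.csv):

```python
import pandas as pd

purchases = pd.DataFrame({
    "country": ["USA", "USA", "Canada"],
    "amount": [2000, 3500, 120],
    "discount": [10, 15, 12],
})

# Precompute the net amount, then a plain single-column aggregation suffices
totals = (purchases
    .assign(net=lambda df: df["amount"] - df["discount"])
    .groupby("country")
    .agg(total=("net", "sum"))
    .reset_index()
)
```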
“Oh, and Maria asked me to remove any outliers.”
# R
purchases |>
filter(amount <= median(amount) * 10) |> #👈
group_by(country) |>
summarize(total = sum(amount - discount))
# A tibble: 11 × 2
country total
<chr> <dbl>
1 Australia 540
2 Brazil 414
3 Canada 270
4 France 450
5 Germany 513
6 India 648
7 Italy 567
8 Japan 621
9 Spain 594
10 UK 432
11 USA 1990
This is also a simple change in Python, using .query():
# Python
(purchases
.query("amount <= amount.median() * 10") #👈
.groupby("country")
.apply(lambda df: (df["amount"] - df["discount"]).sum())
.reset_index()
.rename(columns={0: "total"})
)
country total
0 Australia 540
1 Brazil 414
2 Canada 270
3 France 450
4 Germany 513
5 India 648
6 Italy 567
7 Japan 621
8 Spain 594
9 UK 432
10 USA 1990
(But why is it called .query() when it filters? And why can’t we use DataFrame.filter() instead? Ah, that only filters on the index names. And why do we suddenly have to pass in Python code as a string? Ah, it’s actually not Python, but a language that’s similar to Python. Of course, all these questions have explanations, yet I still can never really remember what I’m allowed to put in a .query() string. Instead of .query() we could use .loc[], but then we need to do a fair bit of typing: .loc[lambda df: df["amount"] <= df["amount"].median() * 10]. Compare that to the R version: filter(amount <= median(amount) * 10).)
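Spelled out on made-up data (not the real purchases.csv), the .loc[] variant of the whole pipeline would look something like this:

```python
import pandas as pd

purchases = pd.DataFrame({
    "country": ["USA", "USA", "Canada", "Canada"],
    "amount": [2000, 3500, 120, 50000],
    "discount": [10, 15, 12, 0],
})

totals = (purchases
    # .loc[] with a lambda instead of a .query() string; the global
    # median here is 2750, so the 50000 outlier gets filtered out
    .loc[lambda df: df["amount"] <= df["amount"].median() * 10]
    .groupby("country")
    .apply(lambda df: (df["amount"] - df["discount"]).sum())
    .reset_index()
    .rename(columns={0: "total"})
)
```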
“I probably should use the median within each country”
# R
purchases |>
group_by(country) |> #👆
filter(amount <= median(amount) * 10) |> #👇
summarize(total = sum(amount - discount))
# A tibble: 11 × 2
country total
<chr> <dbl>
1 Australia 540
2 Brazil 414
3 Canada 270
4 France 450
5 Germany 513
6 India 648
7 Italy 567
8 Japan 621
9 Spain 594
10 UK 432
11 USA 8455
What’s just swapping two lines in R becomes much more involved in Python. The reason for this is that .groupby() doesn’t return a pandas.DataFrame; it returns a pandas.api.typing.DataFrameGroupBy object, which doesn’t have the same set of methods as a regular data frame. In particular, it has neither .query() nor .loc[]. There are two solutions here. The first is to fall back on .apply(), this time returning a filtered version of each group, but then we also need to remove the country index completely with .reset_index(drop=True), as the filtered purchases already has a country column:
# Python
(purchases
.groupby("country") #👈
.apply(lambda df: df[df["amount"] <= df["amount"].median() * 10]) #👈
.reset_index(drop=True) #👈
.groupby("country")
.apply(lambda df: (df["amount"] - df["discount"]).sum())
.reset_index()
.rename(columns={0: "total"})
)
country total
0 Australia 540
1 Brazil 414
2 Canada 270
3 France 450
4 Germany 513
5 India 648
6 Italy 567
7 Japan 621
8 Spain 594
9 UK 432
10 USA 8455
(The fact that grouped and regular pandas data frames have different APIs is a constant source of confusion, to me. One example of this is .filter(), where DataFrameGroupBy.filter() does something completely different from DataFrame.filter(). And neither of them actually filters away values!)
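To illustrate the difference: DataFrameGroupBy.filter() keeps or drops whole groups based on a predicate evaluated per group, rather than filtering individual rows or labels. A sketch on made-up data:

```python
import pandas as pd

df = pd.DataFrame({
    "country": ["USA", "USA", "Canada"],
    "amount": [2000, 3500, 120],
})

# Keeps every row of each group whose *total* amount exceeds 1000,
# i.e. it filters groups, not rows: all USA rows survive, Canada drops
big_markets = df.groupby("country").filter(lambda g: g["amount"].sum() > 1000)
```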
A second solution is to first calculate the median amount per country and assign it to each row in purchases. The upside is that we can continue to use .query(), but at the cost of introducing both .assign() and .transform() into the mix.
# Python
(purchases
.assign(country_median=lambda df: #👈
df.groupby("country")["amount"].transform("median") #👈
)
.query("amount <= country_median * 10") #👈
.groupby("country")
.apply(lambda df: (df["amount"] - df["discount"]).sum())
.reset_index()
.rename(columns={0: "total"})
)
country total
0 Australia 540
1 Brazil 414
2 Canada 270
3 France 450
4 Germany 513
5 India 648
6 Italy 567
7 Japan 621
8 Spain 594
9 UK 432
10 USA 8455
Compare this with, again, the final R solution:
purchases |>
group_by(country) |>
filter(amount <= median(amount) * 10) |>
summarize(total = sum(amount - discount))
This solution is not only shorter but also contains less ‘boilerplate’ code, such as lambda, reset_index, etc. The journey to the R solution was more straightforward, and we could build it up one step at a time. With pandas, we often had to backtrack and switch out parts of the intermediate solution.
So, what’s your point?
My point is that, if you’re a “Python person”, pandas is a great tool, but people with extensive R experience may find working with pandas frustrating for valid reasons. Show them some compassion!
You might think my purchases analysis was just a little toy example, selected to highlight the clunkiness of the pandas API. And yes, partially, but my experience is that with larger, real-world code the problems with the pandas API, outlined in this post, remain. That is, pandas feels clunky when coming from R because:
- The naming of methods and arguments is often confusing (.filter() doesn’t filter values. Will .sum(axis=1) sum the rows or the columns?).
- Different methods are available for grouped and non-grouped data frames, and methods with the same name can do very different things (for example, DataFrame.filter() and DataFrameGroupBy.filter()).
- Many convenience functions are missing from pandas, which means you’ll have to code them from scratch. For instance, moving year to be the first column is df |> relocate(year) in the tidyverse; it’s df[["year"] + [col for col in df.columns if col != "year"]] in pandas.
- Pandas will constantly move columns into the index, and you’ll have to work hard to get that data out again. You’ll be typing .reset_index() many, many times.
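That said, the missing conveniences can be wrapped up once and reused. Here’s a hypothetical relocate() helper of my own, a sketch rather than any pandas or tidyverse API:

```python
import pandas as pd

def relocate(df: pd.DataFrame, *cols: str) -> pd.DataFrame:
    """Move the given columns to the front, keeping the rest in order."""
    rest = [c for c in df.columns if c not in cols]
    return df[list(cols) + rest]

df = pd.DataFrame({"amount": [1], "year": [2024], "country": ["USA"]})
reordered = relocate(df, "year")  # columns: year, amount, country
```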