Aggregate Functions in Snowflake Snowpark


1. Introduction

Aggregate functions perform a calculation on a set of values and return a single value. These functions are often used in conjunction with the GROUP BY clause to perform calculations on groups of rows.
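To make the difference concrete, here is a plain-Python sketch (not Snowpark code) that applies MAX once across all SALARY values from the EMPLOYEE data used later in this article, and then once per group. The grouping key, a salary band split at 15000, is a hypothetical choice made only for this illustration. The article's data has no group column of its own.

```python
# Plain-Python illustration of aggregate semantics (not Snowpark code).
# SALARY values copied from the EMPLOYEE data used in this article.
salaries = {
    "TONY": 24000,
    "STEVE": 17000,
    "BRUCE": 9000,
    "WANDA": 20000,
    "VICTOR": 12000,
    "STEPHEN": 10000,
}

# Without GROUP BY: all input rows collapse into a single value.
overall_max = max(salaries.values())


# With GROUP BY: one result per group. The band key is hypothetical.
def salary_band(salary: int) -> str:
    return "HIGH" if salary >= 15000 else "LOW"


max_by_band: dict = {}
for salary in salaries.values():
    band = salary_band(salary)
    # Keep a running maximum per group, like MAX(SALARY) ... GROUP BY band.
    max_by_band[band] = max(max_by_band.get(band, salary), salary)

print(overall_max)  # 24000
print(max_by_band)  # {'HIGH': 24000, 'LOW': 12000}
```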

For the full list of aggregate functions supported in Snowflake, refer to the Snowflake documentation.

In this article, we will explore how to use aggregate functions in Snowflake Snowpark Python.

2. Aggregate Functions in Snowpark

The DataFrame.agg method in Snowpark aggregates the data in a DataFrame. It accepts the names of valid Snowflake aggregate functions (or Column expressions built from them) and computes each one across the rows of the DataFrame, producing a single output row.

There are several ways to pass DataFrame columns to the DataFrame.agg method for aggregate calculations:

  1. A Column object
  2. A tuple where the first element is a column object or a column name and the second element is the name of the aggregate function
  3. A list of the above
  4. A dictionary that maps column name to an aggregate function name.
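One way to picture these four forms is that each one reduces to the same list of aggregate requests. The plain-Python sketch below models that idea with a hypothetical helper, `normalize_agg_args`. It is not Snowpark's internal implementation, and the placeholder string `AVG_COLUMN` merely stands in for a real Column object such as `avg(col("SALARY"))`.

```python
from typing import Any


def normalize_agg_args(*exprs: Any) -> list:
    """Hypothetical model of flattening DataFrame.agg inputs.

    Tuples and dict entries become (column, function_name) pairs; any other
    item stands in for a Column object and passes through unchanged.
    """
    # Forms 3 and 4: a single list or dict argument is unpacked first.
    if len(exprs) == 1 and isinstance(exprs[0], dict):
        items = list(exprs[0].items())
    elif len(exprs) == 1 and isinstance(exprs[0], list):
        items = list(exprs[0])
    else:
        items = list(exprs)

    normalized = []
    for item in items:
        if isinstance(item, tuple):
            column, func_name = item  # form 2: (column, "function name")
            normalized.append((column, func_name))
        else:
            normalized.append(item)  # form 1: a Column object, kept as-is
    return normalized


# Placeholder standing in for a Column object such as avg(col("SALARY")).
AVG_COLUMN = "<Column: avg(SALARY)>"

print(normalize_agg_args(("SALARY", "max"), ("SALARY", "min")))
print(normalize_agg_args([("SALARY", "min"), AVG_COLUMN]))
print(normalize_agg_args({"SALARY": "min"}))
```

Whichever form is used, the result is the same kind of request: which column to aggregate and with which function.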

3. Demonstration of Aggregate Functions using DataFrame.agg Method in Snowpark

Follow the steps below to perform aggregate calculations using the DataFrame.agg method.

  • STEP-1: Establish a connection with Snowflake from Snowpark using the Session class.
  • STEP-2: Import all the required aggregate functions (min, max, sum, etc.) from the snowflake.snowpark.functions module.
  • STEP-3: Create a DataFrame that holds the data on which aggregate functions are to be applied.
  • STEP-4: Implement aggregate calculations on the DataFrame using the DataFrame.agg method.

Demonstration

The following EMPLOYEE data is used to demonstrate aggregate functions in Snowpark.

#// Creating a DataFrame with EMPLOYEE data
#// Assumes `session` is the Snowpark Session created in STEP-1, e.g.
#// session = Session.builder.configs(connection_parameters).create()
employee_data = [
  [1,'TONY',24000],
  [2,'STEVE',17000],
  [3,'BRUCE',9000],
  [4,'WANDA',20000],
  [5,'VICTOR',12000],
  [6,'STEPHEN',10000]
]

employee_schema = ["EMP_ID", "EMP_NAME", "SALARY"]

df_employee = session.create_dataframe(employee_data, schema=employee_schema)
df_employee.show()

------------------------------------
|"EMP_ID"  |"EMP_NAME"  |"SALARY"  |
------------------------------------
|1         |TONY        |24000     |
|2         |STEVE       |17000     |
|3         |BRUCE       |9000      |
|4         |WANDA       |20000     |
|5         |VICTOR      |12000     |
|6         |STEPHEN     |10000     |
------------------------------------

3.1. Passing a DataFrame Column Object

Before performing aggregate calculations, import the necessary aggregate functions from the snowflake.snowpark.functions module as shown below.

#// Importing the Aggregate Function methods
#// Note: importing min and max this way shadows Python's built-in min() and max()
from snowflake.snowpark.functions import col, min, max, avg

#// Passing a Column object to DataFrame.agg method
df_employee.agg(max("SALARY"), min("SALARY")).show()
---------------------------------
|"MAX(SALARY)"  |"MIN(SALARY)"  |
---------------------------------
|24000          |9000           |
---------------------------------

df_employee.agg(max(col("SALARY")), min(col("SALARY"))).show()
---------------------------------
|"MAX(SALARY)"  |"MIN(SALARY)"  |
---------------------------------
|24000          |9000           |
---------------------------------

3.2. Passing a Tuple with Column Name and Aggregate Function

#// Passing a tuple with column name and aggregate function to DataFrame.agg method
df_employee.agg(("SALARY", "max"), ("SALARY", "min")).show()
---------------------------------
|"MAX(SALARY)"  |"MIN(SALARY)"  |
---------------------------------
|24000          |9000           |
---------------------------------

3.3. Passing a List of Column Objects and Tuples

#// Passing a list of tuples and a Column object
df_employee.agg([("SALARY", "min"), ("SALARY", "max"), avg(col("SALARY"))]).show()
-------------------------------------------------
|"MIN(SALARY)"  |"MAX(SALARY)"  |"AVG(SALARY)"  |
-------------------------------------------------
|9000           |24000          |15333.333333   |
-------------------------------------------------
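As a quick sanity check of the AVG value, the same minimum, maximum, and average can be computed in plain Python over the SALARY values shown above. This is only arithmetic on the sample data, not a Snowpark call, and the rounding to 6 decimal places is chosen here simply to match the displayed output.

```python
# SALARY column from the EMPLOYEE sample data.
salaries = [24000, 17000, 9000, 20000, 12000, 10000]

lowest = min(salaries)     # 9000
highest = max(salaries)    # 24000
# 92000 / 6 = 15333.333..., shown rounded to 6 decimal places.
average = sum(salaries) / len(salaries)

print(lowest, highest, round(average, 6))  # 9000 24000 15333.333333
```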

3.4. Passing a Dictionary Mapping Column Name to Aggregate Function

#// Passing a dictionary mapping column name to aggregate function
df_employee.agg({"SALARY": "min"}).show()
-----------------
|"MIN(SALARY)"  |
-----------------
|9000           |
-----------------
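Because a Python dictionary holds each key only once, the dictionary form can apply only one aggregate function per column. The plain-Python snippet below shows what happens when the same column name appears twice in a dict literal. If such a dict were passed to DataFrame.agg, only the last entry would reach it, so the tuple or list forms are the safer choice for several aggregates on one column.

```python
# A dict literal with a repeated key silently keeps only the last value.
agg_spec = {"SALARY": "min", "SALARY": "max"}
print(agg_spec)  # {'SALARY': 'max'}

# A list of tuples can request several aggregates on the same column.
agg_spec_tuples = [("SALARY", "min"), ("SALARY", "max")]
print(len(agg_spec_tuples))  # 2
```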

4. Aggregate Functions using DataFrame.select Method in Snowpark

The DataFrame.select method returns a new DataFrame containing the specified Column expressions. Because aggregate functions return Column expressions, they can be passed to select to compute aggregates over the DataFrame.

#// Aggregate functions using select method
df_employee.select(min("SALARY"), max("SALARY")).show()
-----------------------------------------
|"MIN(""SALARY"")"  |"MAX(""SALARY"")"  |
-----------------------------------------
|9000               |24000              |
-----------------------------------------
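The header in this output suggests that the generated column name itself contains double quotes, that is, MIN("SALARY"). Snowflake delimits identifiers with double quotes and escapes a literal double quote inside one by doubling it. The sketch below applies that rule in plain Python. `quote_identifier` is a hypothetical helper, and treating MIN("SALARY") as the underlying column name is an assumption inferred from the displayed header.

```python
def quote_identifier(name: str) -> str:
    """Wrap a name in double quotes, doubling any embedded double quotes."""
    return '"' + name.replace('"', '""') + '"'


print(quote_identifier('MIN("SALARY")'))  # "MIN(""SALARY"")"
print(quote_identifier("EMP_ID"))         # "EMP_ID"
```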

5. Renaming the Returned Aggregate Fields

The output fields from the aggregate functions can be renamed using the Column.as_ or Column.alias methods as shown below.

#// Renaming column names
df_employee.agg(min("SALARY").as_("min_sal"), max("SALARY").alias("max_sal")).show()
-------------------------
|"MIN_SAL"  |"MAX_SAL"  |
-------------------------
|9000       |24000      |
-------------------------

df_employee.select(min("SALARY").as_("MIN_SAL"), max("SALARY").alias("MAX_SAL")).show()
-------------------------
|"MIN_SAL"  |"MAX_SAL"  |
-------------------------
|9000       |24000      |
-------------------------
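The lower-case alias min_sal in the first example is displayed as MIN_SAL because Snowflake stores unquoted identifiers in upper case, while a double-quoted identifier keeps its exact case. The plain-Python sketch below models that rule with a hypothetical `resolve_identifier` helper. It is a simplification of Snowflake's identifier rules, not Snowpark's own code. Following the same rule, wrapping an alias in double quotes (for example '"min_sal"') should keep it in lower case.

```python
import re

# Unquoted identifiers: a letter or underscore, then letters, digits, _ or $.
UNQUOTED = re.compile(r"^[A-Za-z_][A-Za-z0-9_$]*$")


def resolve_identifier(name: str) -> str:
    """Simplified model of the name Snowflake stores for an identifier."""
    if len(name) >= 2 and name.startswith('"') and name.endswith('"'):
        # Quoted: case is preserved; doubled quotes are unescaped.
        return name[1:-1].replace('""', '"')
    if UNQUOTED.match(name):
        # Unquoted: case-insensitive, stored in upper case.
        return name.upper()
    raise ValueError(f"{name!r} must be double-quoted to be a valid identifier")


print(resolve_identifier("min_sal"))    # MIN_SAL
print(resolve_identifier('"min_sal"'))  # min_sal
```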

6. Passing Return Value of an Aggregate Function as an Input

Let us understand this with a simple example. Suppose the requirement is to retrieve the details of the employee with the maximum salary. This can be accomplished using the SQL query below.

-- Get employee details with MAX Salary
SELECT * FROM EMPLOYEES
WHERE SALARY IN (SELECT MAX(SALARY) FROM EMPLOYEES);


In the above example,

  • The maximum salary in the table is calculated in a subquery using the MAX aggregate function.
  • The calculated value is then used as a filter on the EMPLOYEES table to return the full details of the matching employee.
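The two steps above can be mirrored in plain Python over the same EMPLOYEE rows: first reduce the SALARY values to a single maximum, then keep the rows whose SALARY equals it. This only illustrates the logic and is not Snowpark code. Like the IN subquery, it returns every employee tied at the maximum, not just one.

```python
# EMPLOYEE rows from the sample data: (EMP_ID, EMP_NAME, SALARY).
employees = [
    (1, "TONY", 24000),
    (2, "STEVE", 17000),
    (3, "BRUCE", 9000),
    (4, "WANDA", 20000),
    (5, "VICTOR", 12000),
    (6, "STEPHEN", 10000),
]

# Step 1: aggregate, like SELECT MAX(SALARY) FROM EMPLOYEES.
max_salary = max(salary for _, _, salary in employees)

# Step 2: filter, like WHERE SALARY IN (<max_salary>).
top_earners = [row for row in employees if row[2] == max_salary]

print(max_salary)   # 24000
print(top_earners)  # [(1, 'TONY', 24000)]
```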

Let us understand how the same can be achieved in Snowpark.

The DataFrame.collect method in Snowpark executes the query defined by the DataFrame and returns the results to the client as a list of Row objects.

In the following code, the max salary is calculated using the DataFrame.agg method, and the result is returned to a variable using the DataFrame.collect method.

max_sal = df_employee.agg(max("SALARY").alias("MAX_SALARY")).collect()

The following code shows that the variable max_sal is of type list, along with the value stored in it.

type(max_sal)
-----------------
|<class 'list'> |
-----------------

print(max_sal)
----------------------------
|[Row(MAX_SALARY=24000)]   |
----------------------------

Extract the max salary value from the first Row object in the list as shown below.

max_sal = max_sal[0]['MAX_SALARY']

type(max_sal)
-----------------
|<class 'int'>  |
-----------------

print(max_sal)
----------
|24000   |
----------

The DataFrame.filter method filters rows from a DataFrame based on the specified conditional expression (similar to WHERE in SQL).

The following code extracts the employee details with max salary.

#// Get employee details with max salary
df_employee.filter(col("SALARY") == max_sal).show()
------------------------------------
|"EMP_ID"  |"EMP_NAME"  |"SALARY"  |
------------------------------------
|1         |TONY        |24000     |
------------------------------------
