HOW TO: Update a DataFrame in Snowflake Snowpark?


1. Introduction

The Table.update() method in Snowpark updates the rows of a table in Snowflake. It returns an UpdateResult named tuple containing the number of rows updated and the number of multi-joined rows updated. A Table object is a DataFrame that references a table, so the same method can be used to update the contents of a DataFrame once its data has been saved to a table.

Syntax

Table.update(<assignments>, [<condition>], [<source>])

Parameters

  • <assignments>
    A dictionary whose keys are column names and whose values are the new values for those columns. Each value can be either a literal or a Column expression.
  • <condition>
    An optional Column expression that selects which rows are updated. If no condition is specified, all rows are updated.
  • <source>
    An optional DataFrame whose data is used to update the current table. When a source is given, the join condition between the two DataFrames must be specified in <condition>.
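To make the three parameters concrete, here is a minimal plain-Python model of what Table.update() does to a table's rows. It is not Snowpark code and never talks to Snowflake: rows are dictionaries, assignment values and conditions are ordinary Python callables standing in for Column expressions, and the names update_rows, Row, Condition and the local UpdateResult class are all hypothetical. The multi-joined count assumes that a target row is "multi-joined" when it matches more than one source row; where Snowflake would pick one of those source rows nondeterministically, the model simply takes the first.

```python
from typing import Any, Callable, Dict, List, NamedTuple, Optional

# Hypothetical row representation: column name -> value.
Row = Dict[str, Any]
# A condition receives the target row and the candidate source row (None when there is no source).
Condition = Callable[[Row, Optional[Row]], bool]


class UpdateResult(NamedTuple):
    """Local stand-in for the two counts reported by Snowpark's UpdateResult."""
    rows_updated: int
    multi_joined_rows_updated: int


def update_rows(
    rows: List[Row],
    assignments: Dict[str, Any],
    condition: Optional[Condition] = None,
    source: Optional[List[Row]] = None,
) -> UpdateResult:
    """Model of Table.update(assignments, condition, source); mutates `rows` in place.

    Each assignment value is either a literal or a callable(target, source_row)
    that computes the new value, playing the role of a Column expression.
    """
    updated = 0
    multi_joined = 0
    for row in rows:
        # Without a source, each row is tested on its own; with a source,
        # the condition acts as the join predicate against every source row.
        candidates: List[Optional[Row]] = [None] if source is None else list(source)
        matches = [s for s in candidates if condition is None or condition(row, s)]
        if not matches:
            continue  # the condition excludes this row
        if len(matches) > 1:
            # Assumption: this row is multi-joined; Snowflake would pick one
            # matching source row nondeterministically, the model takes the first.
            multi_joined += 1
        src = matches[0]
        # Compute every new value from the current row before writing any of them.
        new_values = {
            column: value(row, src) if callable(value) else value
            for column, value in assignments.items()
        }
        row.update(new_values)
        updated += 1
    return UpdateResult(updated, multi_joined)
```

For example, `update_rows(rows, {"SALARY": lambda t, s: t["SALARY"] * 2})` doubles every salary, mirroring an update with no condition, while passing a list of source rows plus a join predicate mirrors the <source> form.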

2. Steps to Update a DataFrame in Snowpark

Follow the steps below to update the data of a DataFrame in Snowpark using the Table.update() method.

  1. Create a DataFrame with the desired data, for example using Session.createDataFrame(). The data can come from an existing table, a CSV file, or values defined in the code.
  2. Save the contents of the DataFrame to a temporary table using the DataFrameWriter class (DataFrame.write).
  3. Create a DataFrame that reads the temporary table using the Session.table() method. This returns a Table object, which supports update().
  4. Update the contents of this DataFrame using the Table.update() method.
  5. Verify that the appropriate records have been updated by displaying the DataFrame with the DataFrame.show() method.

Temporary tables exist only within the session in which they were created and are not visible to other users or sessions. Once the session ends, the table and its data are purged from the system. This makes temporary tables well suited for updating DataFrames: the changes never touch permanent tables and need no cleanup afterwards.
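Steps 2 and 3 are identical in every demonstration, so they can be wrapped in a small helper. The sketch below is hypothetical (to_updatable_table is not part of Snowpark); it uses only the calls shown in this article, DataFrame.write.mode(...).save_as_table(..., table_type="temp") and Session.table(), and assumes an existing Snowpark session is passed in.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:  # imported only for type hints; not needed at runtime
    from snowflake.snowpark import DataFrame, Session, Table


def to_updatable_table(session: "Session", df: "DataFrame", table_name: str) -> "Table":
    """Hypothetical helper: copy `df` into a temporary table and return it as a Table."""
    # Step 2: persist the DataFrame; "overwrite" replaces any existing table of the same name.
    df.write.mode("overwrite").save_as_table(table_name, table_type="temp")
    # Step 3: session.table() returns a Table, the DataFrame subtype that has update().
    return session.table(table_name)
```

With it, STEP-2 and STEP-3 below would collapse to `df_tmp_emp = to_updatable_table(session, df_emp, "tmp_emp")`, after which `df_tmp_emp.update(...)` is called as in the demonstrations.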

3. Demonstration of Updating all rows of a DataFrame

STEP-1: Create DataFrame

The following code creates a DataFrame df_emp that holds the EMPLOYEES data shown below.

#// create a DataFrame with employee data
employee_data = [
    [1,'TONY',24000,10],
    [2,'STEVE',17000,10],
    [3,'BRUCE',9000,20],
    [4,'WANDA',20000,20]
]

employee_schema = ["EMP_ID", "EMP_NAME", "SALARY", "DEPT_ID"]

df_emp = session.createDataFrame(employee_data, schema=employee_schema)
df_emp.show()

------------------------------------------------
|"EMP_ID"  |"EMP_NAME"  |"SALARY"  |"DEPT_ID"  |
------------------------------------------------
|1         |TONY        |24000     |10         |
|2         |STEVE       |17000     |10         |
|3         |BRUCE       |9000      |20         |
|4         |WANDA       |20000     |20         |
------------------------------------------------

STEP-2: Create Temporary Table

The following code creates a temporary table named tmp_emp in the Snowflake database using the contents of the df_emp DataFrame. The "overwrite" mode replaces the table if it already exists.

#// create a temp table
df_emp.write.mode("overwrite").save_as_table("tmp_emp", table_type="temp")

STEP-3: Read Temporary Table

The following code creates a new DataFrame df_tmp_emp that reads the contents of the temporary table tmp_emp.

#// create a DataFrame to read contents of temp table
df_tmp_emp = session.table("tmp_emp")
df_tmp_emp.show()

------------------------------------------------
|"EMP_ID"  |"EMP_NAME"  |"SALARY"  |"DEPT_ID"  |
------------------------------------------------
|1         |TONY        |24000     |10         |
|2         |STEVE       |17000     |10         |
|3         |BRUCE       |9000      |20         |
|4         |WANDA       |20000     |20         |
------------------------------------------------

STEP-4: Update DataFrame

The following code updates all records of the DataFrame df_tmp_emp, multiplying the DEPT_ID values by 10 and doubling the SALARY amounts.

#// update DEPT_ID and SALARY fields of all records
from snowflake.snowpark.types import IntegerType
from snowflake.snowpark.functions import cast

df_tmp_emp.update({
    "DEPT_ID": cast("DEPT_ID", IntegerType()) * 10,
    "SALARY": cast("SALARY", IntegerType()) * 2,
})
#// Output: UpdateResult(rows_updated=4, multi_joined_rows_updated=0)

Note that we have used the cast function to convert the DEPT_ID and SALARY fields to Integer type before updating them.

STEP-5: Display Updated DataFrame

The following code displays the contents of the updated DataFrame.

#// display updated DataFrame
df_tmp_emp.show()

------------------------------------------------
|"EMP_ID"  |"EMP_NAME"  |"SALARY"  |"DEPT_ID"  |
------------------------------------------------
|1         |TONY        |48000     |100        |
|2         |STEVE       |34000     |100        |
|3         |BRUCE       |18000     |200        |
|4         |WANDA       |40000     |200        |
------------------------------------------------

4. Updating a DataFrame based on a Condition

The following code increases the SALARY of all employees in department 100 by 100.

#// update the SALARY field of employees where DEPT_ID is 100

df_tmp_emp.update({"SALARY": cast("SALARY", IntegerType()) + 100}, df_tmp_emp["DEPT_ID"] == 100)
#// Output: UpdateResult(rows_updated=2, multi_joined_rows_updated=0)

df_tmp_emp.show()

------------------------------------------------
|"EMP_ID"  |"EMP_NAME"  |"SALARY"  |"DEPT_ID"  |
------------------------------------------------
|1         |TONY        |48100     |100        |
|2         |STEVE       |34100     |100        |
|3         |BRUCE       |18000     |200        |
|4         |WANDA       |40000     |200        |
------------------------------------------------

5. Updating a DataFrame based on data in another DataFrame

A DataFrame can also be updated based on the data in another DataFrame by passing that DataFrame as the <source> argument of the Table.update() method.

The following code sets the SALARY of each employee in the df_tmp_emp DataFrame to the SALARY value from another DataFrame, df_salary, for every row whose EMP_ID matches an EMP_ID in df_salary.

#// update DataFrame based on data in another DataFrame

df_salary = session.createDataFrame([[1, 50000], [2, 35000]], schema=["EMP_ID", "SALARY"])
df_salary.show()

-----------------------
|"EMP_ID"  |"SALARY"  |
-----------------------
|1         |50000     |
|2         |35000     |
-----------------------

df_tmp_emp.update({"SALARY": df_salary["SALARY"]}, df_tmp_emp["EMP_ID"] == df_salary["EMP_ID"], df_salary)
#// Output: UpdateResult(rows_updated=2, multi_joined_rows_updated=0)

df_tmp_emp.show()

------------------------------------------------
|"EMP_ID"  |"EMP_NAME"  |"SALARY"  |"DEPT_ID"  |
------------------------------------------------
|1         |TONY        |50000     |100        |
|2         |STEVE       |35000     |100        |
|3         |BRUCE       |18000     |200        |
|4         |WANDA       |40000     |200        |
------------------------------------------------

6. Updating a DataFrame using Session.sql() Method

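The Session.sql() method in Snowpark can be used to execute a SQL statement. It returns a new DataFrame representing the results of the statement. Evaluation is lazy: the statement only runs when an action such as DataFrame.collect() is called, which is why the UPDATE statement in this section ends with collect().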
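Because the result of an UPDATE comes back as a single Row holding two counts, a small wrapper can execute the statement and unpack them. This is a hypothetical helper (run_update is not a Snowpark function); it assumes an existing session and that the counts are the first two fields of the returned Row, as in the output shown in this section.

```python
from typing import Tuple


def run_update(session, statement: str) -> Tuple[int, int]:
    """Hypothetical helper: execute a SQL UPDATE via session.sql() and return its two counts."""
    # session.sql() only builds a lazy DataFrame; collect() is what runs the statement.
    rows = session.sql(statement).collect()
    # Assumption: a single Row whose first two fields are
    # "number of rows updated" and "number of multi-joined rows updated".
    result = rows[0]
    return int(result[0]), int(result[1])
```

For example, `updated, multi_joined = run_update(session, "UPDATE tmp_emp SET SALARY=70000 WHERE EMP_ID=3")`. Because the statement is passed as a plain string, avoid assembling it from untrusted input.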

Follow the steps below to update the data of a DataFrame in Snowpark using the Session.sql() method.

  1. Create a DataFrame with the desired data, for example using Session.createDataFrame(). The data can come from an existing table, a CSV file, or values defined in the code.
  2. Save the contents of the DataFrame to a temporary table using the DataFrameWriter class (DataFrame.write).
  3. Run an UPDATE statement against the temporary table using the Session.sql() method, followed by collect() to execute it.
  4. Create a DataFrame that reads the updated temporary table using the Session.table() method.
  5. Verify that the appropriate records have been updated by displaying the DataFrame with the DataFrame.show() method.

#// create DataFrame
employee_data = [
    [1,'TONY',24000,10],
    [2,'STEVE',17000,10],
    [3,'BRUCE',9000,20],
    [4,'WANDA',20000,20]
]
employee_schema = ["EMP_ID", "EMP_NAME", "SALARY", "DEPT_ID"]
df_emp = session.createDataFrame(employee_data, schema=employee_schema)

#// create temporary table
df_emp.write.mode("overwrite").save_as_table("tmp_emp", table_type="temp")

#// update DataFrame using session.sql()
session.sql("UPDATE tmp_emp SET SALARY=70000 WHERE EMP_ID=3").collect()
#// Output: [Row(number of rows updated=1, number of multi-joined rows updated=0)]

#// create DataFrame to read contents of updated temp table
df_tmp_emp = session.table("tmp_emp")

#// display updated DataFrame
df_tmp_emp.show()

------------------------------------------------
|"EMP_ID"  |"EMP_NAME"  |"SALARY"  |"DEPT_ID"  |
------------------------------------------------
|1         |TONY        |24000     |10         |
|2         |STEVE       |17000     |10         |
|3         |BRUCE       |70000     |20         |
|4         |WANDA       |20000     |20         |
------------------------------------------------
