I want to get all the table names from a SQL query in Spark using Scala.
Let's say a user sends a SQL query that looks like:
select * from table_1 as
def __sqlparse2table(self, query):
    '''
    @description: collect the schema and table names a SQL query reads from,
                  by scanning the text of its unresolved logical plan
    @param query: SQL text passed to the Spark SQL parser (parsed only,
                  never executed)
    @return: tuple (schema, tables) of two sets of strings; tables holds each
             referenced relation name (dot-joined when qualified), schema
             holds the part before the first dot of each of those names.
             Names introduced by a WITH clause (CTEs) are excluded.
    '''
    plan = self.spark._jsparkSession.sessionState().sqlParser().parsePlan(query)
    # Render the plan once; both the relation and the CTE scan use this text.
    plan_text = plan.toString()

    # Every WITH clause, including ones nested in subqueries, prints its own
    # "CTE [a, b]" node, so names from all matches must be gathered.
    cte_names = {
        name.strip()
        for group in re.findall(r"CTE \[(.*?)\]", plan_text)
        for name in group.split(',')
    }

    # Backtick form: "UnresolvedRelation `db`.`tbl`". Collapsing "`.`" first
    # lets one capture group return the whole dotted name.
    relations = re.findall(
        r"UnresolvedRelation `(.*?)`", plan_text.replace('`.`', '.')
    )
    # Bracket form printed by Spark 3.x plans:
    # "UnresolvedRelation [db, tbl], [], false". The identifier parts are
    # re-joined with dots so both forms yield the same kind of name.
    relations.extend(
        '.'.join(part.strip() for part in ident.split(','))
        for ident in re.findall(r"UnresolvedRelation \[(.*?)\]", plan_text)
    )

    schema = set()
    tables = set()
    for table_name in relations:
        if table_name in cte_names:
            continue
        # NOTE(review): for an unqualified name this adds the table name
        # itself to `schema`; kept as-is because callers may rely on it.
        schema.add(table_name.split('.')[0])
        tables.add(table_name)
    return schema, tables