Skip to content

Commit

Permalink
feat(sql): adds session sql for leveraging attached catalogs (#3860)
Browse files Browse the repository at this point in the history
This PR adds a `session.sql()` API which allows leveraging attached
catalogs and management of session state for name resolution. This
allows to properly resolve qualified names across various catalog
implementations. For now, this is separate from the global session as to
not introduce any breaking changes.



**Example**

```python
    # attach some external catalogs
    cat_1 = _create_catalog("cat_1", tmpdir)
    cat_2 = _create_catalog("cat_2", tmpdir)
    # attach to a new session
    sess = Session()
    sess.attach_catalog(cat_1, alias="cat_1")
    sess.attach_catalog(cat_2, alias="cat_2")
    return sess

    #
    # unqualified should only work in cat_1 and ns_1
    sess.set_catalog("cat_1")
    sess.set_namespace("ns_1")
    assert sess.sql("select * from tbl_cat_1_11") is not None
    assert sess.sql("select * from tbl_cat_1_12") is not None
    #
    # schema-qualified and  should still work.
    assert sess.sql("select * from ns_1.tbl_cat_1_11") is not None
    assert sess.sql("select * from ns_2.tbl_cat_1_21") is not None
    #
    # catalog-qualified should still work.
    assert sess.sql("select * from cat_1.ns_1.tbl_cat_1_11") is not None
    assert sess.sql("select * from cat_1.ns_2.tbl_cat_1_21") is not None
    #
    # err! should not find unqualified things from ns_2
    with pytest.raises(Exception, match="not found"):
        sess.sql("select * from tbl_cat_1_21")
    with pytest.raises(Exception, match="not found"):
        sess.sql("select * from tbl_cat_1_22")
    #
    # find in cat_1.ns_2 only if schema-qualified
    assert sess.sql("select * from ns_2.tbl_cat_1_21") is not None
    assert sess.sql("select * from ns_2.tbl_cat_1_22") is not None
    #
    # find in cat_2 only if catalog-qualified
    assert sess.sql("select * from cat_2.ns_1.tbl_cat_2_11") is not None
    assert sess.sql("select * from cat_2.ns_2.tbl_cat_2_21") is not None
```
  • Loading branch information
rchowell authored Feb 26, 2025
1 parent 3a3c66f commit 0080a7e
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 11 deletions.
1 change: 1 addition & 0 deletions daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1087,6 +1087,7 @@ class SQLFunctionStub:
@property
def arg_names(self) -> list[str]: ...

def plan_sql(source: str, session: PySession, config: PyDaftPlanningConfig) -> LogicalPlanBuilder: ...
def sql(sql: str, catalog: PyCatalog, daft_planning_config: PyDaftPlanningConfig) -> LogicalPlanBuilder: ...
def sql_expr(sql: str) -> PyExpr: ...
def list_sql_functions() -> list[SQLFunctionStub]: ...
Expand Down
22 changes: 15 additions & 7 deletions daft/session.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from daft.catalog import Catalog, Identifier, Table, TableSource
from daft.daft import PySession

if TYPE_CHECKING:
from daft.dataframe import DataFrame

from daft.context import get_context
from daft.daft import PySession, plan_sql
from daft.dataframe import DataFrame
from daft.logical.builder import LogicalPlanBuilder

__all__ = [
"Session",
Expand Down Expand Up @@ -57,6 +54,17 @@ def _from_env() -> Session:
# todo session builders, raise if DAFT_SESSION=0
return Session()

###
# exec
###

def sql(self, sql: str) -> DataFrame:
"""Executes the SQL statement using this session."""
py_sess = self._session
py_config = get_context().daft_planning_config
py_builder = plan_sql(sql, py_sess, py_config)
return DataFrame(LogicalPlanBuilder(py_builder))

###
# attach & detach
###
Expand Down
17 changes: 15 additions & 2 deletions src/daft-catalog/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,24 @@ impl PyTableWrapper {

impl Table for PyTableWrapper {
fn get_schema(&self) -> SchemaRef {
todo!()
todo!("get_schema")
}

fn get_logical_plan(&self) -> Result<LogicalPlanRef> {
todo!()
Python::with_gil(|py| {
// table = 'python table object'
let table = self.0.bind(py);
// df = table.read()
let df = table.call_method0("read")?;
// builder = df._builder._builder
let builder = df.getattr("_builder")?.getattr("_builder")?;
// builder as PyLogicalPlanBuilder
let builder = builder
.downcast::<PyLogicalPlanBuilder>()
.expect("downcast to PyLogicalPlanBuilder failed")
.borrow();
Ok(builder.builder.plan.clone())
})
}

fn to_py(&self, py: Python<'_>) -> PyResult<PyObject> {
Expand Down
6 changes: 6 additions & 0 deletions src/daft-session/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ impl PySession {
}
}

impl From<&PySession> for Session {
fn from(sess: &PySession) -> Self {
sess.0.clone()
}
}

pub fn register_modules(parent: &Bound<PyModule>) -> PyResult<()> {
parent.add_class::<PySession>()?;
Ok(())
Expand Down
1 change: 1 addition & 0 deletions src/daft-sql/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use pyo3::prelude::*;
#[cfg(feature = "python")]
pub fn register_modules(parent: &Bound<PyModule>) -> PyResult<()> {
parent.add_class::<python::PyCatalog>()?;
parent.add_function(wrap_pyfunction!(python::plan_sql, parent)?)?;
parent.add_function(wrap_pyfunction!(python::sql, parent)?)?;
parent.add_function(wrap_pyfunction!(python::sql_expr, parent)?)?;
parent.add_function(wrap_pyfunction!(python::list_sql_functions, parent)?)?;
Expand Down
15 changes: 13 additions & 2 deletions src/daft-sql/src/python.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use std::{collections::HashMap, sync::Arc};
use std::{collections::HashMap, rc::Rc, sync::Arc};

use common_daft_config::PyDaftPlanningConfig;
use daft_catalog::TableSource;
use daft_dsl::python::PyExpr;
use daft_logical_plan::{LogicalPlan, LogicalPlanBuilder, PyLogicalPlanBuilder};
use daft_session::Session;
use daft_session::{python::PySession, Session};
use pyo3::prelude::*;

use crate::{functions::SQL_FUNCTIONS, planner::SQLPlanner};
Expand Down Expand Up @@ -34,6 +34,17 @@ impl SQLFunctionStub {
}
}

#[pyfunction]
pub fn plan_sql(
sql: &str,
session: &PySession,
config: PyDaftPlanningConfig,
) -> PyResult<PyLogicalPlanBuilder> {
let sess = Rc::new(session.into());
let plan = SQLPlanner::new(sess).plan_sql(sql)?;
Ok(LogicalPlanBuilder::new(plan, Some(config.config)).into())
}

#[pyfunction]
pub fn sql(
sql: &str,
Expand Down
109 changes: 109 additions & 0 deletions tests/sql/test_sess_sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import pyarrow as pa
import pytest

from daft import Catalog, DataFrame, Session


def assert_eq(actual: DataFrame, expect: DataFrame):
assert actual.to_pydict() == expect.to_pydict()


def _create_catalog(name: str, tmpdir: str):
from pyiceberg.catalog.sql import SqlCatalog

catalog = SqlCatalog(
name,
**{
"uri": f"sqlite:///{tmpdir}/pytest_sql_{name}.db",
"warehouse": f"file://{tmpdir}",
},
)
# using naming convention "tbl_<catalog>_<namespace #><table #>"
# which let's us know where the table is coming from with only its name
catalog.create_namespace("ns_1")
catalog.create_namespace("ns_2")
catalog.create_table(f"ns_1.tbl_{name}_11", pa.schema([]))
catalog.create_table(f"ns_1.tbl_{name}_12", pa.schema([]))
catalog.create_table(f"ns_2.tbl_{name}_21", pa.schema([]))
catalog.create_table(f"ns_2.tbl_{name}_22", pa.schema([]))
return Catalog.from_iceberg(catalog)


@pytest.fixture()
def sess(tmpdir) -> Session:
# create some tmp catalogs
cat_1 = _create_catalog("cat_1", tmpdir)
cat_2 = _create_catalog("cat_2", tmpdir)
# attach to a new session
sess = Session()
sess.attach_catalog(cat_1, alias="cat_1")
sess.attach_catalog(cat_2, alias="cat_2")
return sess


# chore: consider reducing test verbosity
# def try_resolve(sess: Session, ident) -> DataFrame:
# sess.sql(f"select * from {ident}")

# chore: consider reducing test verbosity
# def assert_resolve(sess: Session, ident: str):
# assert try_resolve(ident) is not None


def test_catatalog_qualified_idents(sess: Session):
# catalog-qualified
assert sess.sql("select * from cat_1.ns_1.tbl_cat_1_11") is not None
assert sess.sql("select * from cat_2.ns_1.tbl_cat_2_11") is not None


def test_schema_qualified_idents(sess: Session):
#
# schema-qualified should work for cat_1
sess.set_catalog("cat_1")
assert sess.sql("select * from ns_1.tbl_cat_1_11") is not None
assert sess.sql("select * from ns_2.tbl_cat_1_21") is not None
#
# catalog-qualified should still work.
assert sess.sql("select * from cat_1.ns_1.tbl_cat_1_11") is not None
assert sess.sql("select * from cat_1.ns_2.tbl_cat_1_21") is not None
#
# err! should not find things from cat_2
with pytest.raises(Exception, match="not found"):
sess.sql("select * from ns_1.tbl_cat_2_11")
with pytest.raises(Exception, match="not found"):
sess.sql("select * from ns_2.tbl_cat_2_21")
#
# find in cat_2 only if catalog-qualified
assert sess.sql("select * from cat_2.ns_1.tbl_cat_2_11") is not None
assert sess.sql("select * from cat_2.ns_2.tbl_cat_2_21") is not None


def test_unqualified_idents(sess: Session):
#
# unqualified should only work in cat_1 and ns_1
sess.set_catalog("cat_1")
sess.set_namespace("ns_1")
assert sess.sql("select * from tbl_cat_1_11") is not None
assert sess.sql("select * from tbl_cat_1_12") is not None
#
# schema-qualified and should still work.
assert sess.sql("select * from ns_1.tbl_cat_1_11") is not None
assert sess.sql("select * from ns_2.tbl_cat_1_21") is not None
#
# catalog-qualified should still work.
assert sess.sql("select * from cat_1.ns_1.tbl_cat_1_11") is not None
assert sess.sql("select * from cat_1.ns_2.tbl_cat_1_21") is not None
#
# err! should not find unqualified things from ns_2
with pytest.raises(Exception, match="not found"):
sess.sql("select * from tbl_cat_1_21")
with pytest.raises(Exception, match="not found"):
sess.sql("select * from tbl_cat_1_22")
#
# find in cat_1.ns_2 only if schema-qualified
assert sess.sql("select * from ns_2.tbl_cat_1_21") is not None
assert sess.sql("select * from ns_2.tbl_cat_1_22") is not None
#
# find in cat_2 only if catalog-qualified
assert sess.sql("select * from cat_2.ns_1.tbl_cat_2_11") is not None
assert sess.sql("select * from cat_2.ns_2.tbl_cat_2_21") is not None

0 comments on commit 0080a7e

Please sign in to comment.