Skip to content

Commit 7374b18

Browse files
authored
add indexed fields support to python api (apache#1502)
* add nested struct support to python: implements nested structs `col("a")['b']` * add test for indexed fields
1 parent 91ee5a4 commit 7374b18

2 files changed

Lines changed: 38 additions & 0 deletions

File tree

python/datafusion/tests/test_dataframe.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,19 @@ def df():
3535
return ctx.create_dataframe([[batch]])
3636

3737

38+
@pytest.fixture
def struct_df():
    """Fixture: a DataFrame with struct column ``a`` (one field, ``c``) and integer column ``b``."""
    ctx = ExecutionContext()

    # column "a" holds structs with the single field "c"; column "b" holds plain integers
    struct_values = pa.array([{"c": 1}, {"c": 2}, {"c": 3}])
    int_values = pa.array([4, 5, 6])

    # assemble both columns into one RecordBatch and build the DataFrame from it
    record_batch = pa.RecordBatch.from_arrays(
        [struct_values, int_values],
        names=["a", "b"],
    )
    return ctx.create_dataframe([[record_batch]])
49+
50+
3851
def test_select(df):
3952
df = df.select(
4053
column("a") + column("b"),
@@ -153,3 +166,16 @@ def test_get_dataframe(tmp_path):
153166

154167
df = ctx.table("csv")
155168
assert isinstance(df, DataFrame)
169+
170+
171+
def test_struct_select(struct_df):
    """Indexing a struct column with ``["c"]`` yields an expression usable in arithmetic."""
    field_c = column("a")["c"]
    col_b = column("b")

    projected = struct_df.select(field_c + col_b, field_c - col_b)

    # the fixture supplies a single batch, so the first collected batch is the whole result
    batch = projected.collect()[0]

    expected_sum = pa.array([5, 7, 9])
    expected_diff = pa.array([-3, -3, -3])
    assert batch.column(0) == expected_sum
    assert batch.column(1) == expected_diff

python/src/expression.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use pyo3::PyMappingProtocol;
1819
use pyo3::{basic::CompareOp, prelude::*, PyNumberProtocol, PyObjectProtocol};
1920
use std::convert::{From, Into};
2021

@@ -133,3 +134,14 @@ impl PyExpr {
133134
expr.into()
134135
}
135136
}
137+
138+
#[pyproto]
139+
impl PyMappingProtocol for PyExpr {
140+
fn __getitem__(&self, key: &str) -> PyResult<PyExpr> {
141+
Ok(Expr::GetIndexedField {
142+
expr: Box::new(self.expr.clone()),
143+
key: ScalarValue::Utf8(Some(key.to_string()).to_owned()),
144+
}
145+
.into())
146+
}
147+
}

0 commit comments

Comments
 (0)