support external diff column in postgres snapshot connector (#8979)

GitOrigin-RevId: e1c87bea13a6a522878122b40c3e460d3295b573
This commit is contained in:
Sergey Kulik 2025-07-10 18:16:37 +02:00 committed by Manul from Pathway
parent 4f465a7597
commit 9e9d8fc4e8
7 changed files with 183 additions and 10 deletions

View File

@ -409,3 +409,78 @@ def test_psql_json_datetimes(postgres):
assert result["a"] == expected
assert result["b"] == expected
assert result["c"] == expected
def test_psql_external_diff_column(tmp_path, postgres):
class InputSchema(pw.Schema):
name: str = pw.column_definition(primary_key=True)
count: int
price: float
available: bool
external_diff: int
input_path = tmp_path / "input.txt"
output_table = postgres.create_table(InputSchema, used_for_output=True)
def _run(test_items: list[dict]) -> None:
G.clear()
with open(input_path, "w") as f:
for test_item in test_items:
f.write(json.dumps(test_item) + "\n")
table = pw.io.jsonlines.read(input_path, schema=InputSchema, mode="static")
pw.io.postgres.write_snapshot(
table,
POSTGRES_SETTINGS,
output_table,
["name"],
_external_diff_column=table.external_diff,
)
run()
test_items = [
{
"name": "Milk",
"count": 500,
"price": 1.5,
"available": False,
"external_diff": 1,
},
{
"name": "Water",
"count": 600,
"price": 0.5,
"available": True,
"external_diff": 1,
},
]
_run(test_items)
rows = postgres.get_table_contents(output_table, InputSchema.column_names())
rows.sort(key=lambda item: (item["name"], item["available"]))
assert rows == test_items
# Also test that the junk data in the additional columns would not break deletion
new_test_items = [
{
"name": "Milk",
"count": -1,
"price": -1.0,
"available": True,
"external_diff": -1,
}
]
_run(new_test_items)
rows = postgres.get_table_contents(
output_table, InputSchema.column_names(), ("name", "available")
)
expected_rows = [
{
"name": "Water",
"count": 600,
"price": 0.5,
"available": True,
"external_diff": 1,
},
]
assert rows == expected_rows

View File

@ -488,3 +488,14 @@ def _prepare_s3_connection_engine_settings(
if aws_s3_settings is None:
return None
return aws_s3_settings.settings
def get_column_index(table: Table, column: ColumnReference | None) -> int | None:
if column is None:
return None
if column._table != table:
raise ValueError(f"The column {column} doesn't belong to the target table")
for index, table_column in enumerate(table._columns):
if table_column == column.name:
return index
raise RuntimeError(f"The column {column} is not found in the table {table}")

View File

@ -4,12 +4,13 @@ from __future__ import annotations
from typing import Iterable, Literal
from pathway.internals import api, datasink
from pathway.internals import api, datasink, dtype
from pathway.internals._io_helpers import _format_output_value_fields
from pathway.internals.expression import ColumnReference
from pathway.internals.runtime_type_check import check_arg_types
from pathway.internals.table import Table
from pathway.internals.trace import trace_user_frame
from pathway.io._utils import get_column_index
def _connection_string_from_settings(settings: dict):
@ -159,6 +160,7 @@ def write_snapshot(
init_mode: Literal["default", "create_if_not_exists", "replace"] = "default",
name: str | None = None,
sort_by: Iterable[ColumnReference] | None = None,
_external_diff_column: ColumnReference | None = None,
) -> None:
"""Maintains a snapshot of a table within a Postgres table.
@ -232,11 +234,19 @@ def write_snapshot(
table_name=table_name,
sql_writer_init_mode=_init_mode_from_str(init_mode),
)
if (
_external_diff_column is not None
and _external_diff_column._column.dtype != dtype.INT
):
raise ValueError("_external_diff_column can only have an integer type")
external_diff_column_index = get_column_index(table, _external_diff_column)
data_format = api.DataFormat(
format_type="sql_snapshot",
key_field_names=primary_key,
value_fields=_format_output_value_fields(table),
table_name=table_name,
external_diff_column_index=external_diff_column_index,
)
table.to(

View File

@ -9,6 +9,7 @@ from pathway.internals.expression import ColumnReference
from pathway.internals.runtime_type_check import check_arg_types
from pathway.internals.table import Table
from pathway.internals.trace import trace_user_frame
from pathway.io._utils import get_column_index
@check_arg_types
@ -142,13 +143,7 @@ def write(
f"designated_timestamp is passed, but designated_timestamp_policy is {designated_timestamp_policy}"
)
designated_timestamp_policy = "use_column"
if designated_timestamp._table != table:
raise ValueError(
f"The column {designated_timestamp} doesn't belong to the target table"
)
for index, column in enumerate(table._columns):
if column == designated_timestamp.name:
designated_timestamp_index = index
designated_timestamp_index = get_column_index(table, designated_timestamp)
elif designated_timestamp_policy is None:
designated_timestamp_policy = "use_now"

View File

@ -462,6 +462,9 @@ pub enum FormatterError {
#[error(transparent)]
SchemaRepository(#[from] SchemaRepositoryError),
#[error("incorrect external diff value: {0}")]
IncorrectDiffColumnValue(Value),
}
pub trait Formatter: Send {
@ -1877,6 +1880,7 @@ pub struct PsqlSnapshotFormatter {
key_field_positions: Vec<usize>,
value_field_positions: Vec<usize>,
external_diff_column_index: Option<usize>,
}
impl PsqlSnapshotFormatter {
@ -1884,8 +1888,9 @@ impl PsqlSnapshotFormatter {
table_name: String,
mut key_field_names: Vec<String>,
mut value_field_names: Vec<String>,
external_diff_column_index: Option<usize>,
) -> Result<PsqlSnapshotFormatter, PsqlSnapshotFormatterError> {
let mut field_positions = HashMap::<String, usize>::new();
let mut field_positions = HashMap::<String, usize>::with_capacity(value_field_names.len());
for (index, field_name) in value_field_names.iter_mut().enumerate() {
if field_positions.contains_key(field_name) {
return Err(PsqlSnapshotFormatterError::RepeatedValueField(take(
@ -1914,6 +1919,7 @@ impl PsqlSnapshotFormatter {
key_field_positions,
value_field_positions,
external_diff_column_index,
})
}
}
@ -1932,7 +1938,20 @@ impl Formatter for PsqlSnapshotFormatter {
let mut result = Vec::new();
if diff == 1 {
let effective_diff: isize =
if let Some(external_diff_column_index) = self.external_diff_column_index {
let value = &values[external_diff_column_index];
match value {
Value::Int(x) if *x == -1 || *x == 1 => (*x)
.try_into()
.expect("the values from {-1, 1} must convert into isize"),
_ => return Err(FormatterError::IncorrectDiffColumnValue(value.clone())),
}
} else {
diff
};
if effective_diff == 1 {
let update_pairs = self
.value_field_positions
.iter()

View File

@ -4686,6 +4686,7 @@ pub struct DataFormat {
schema_registry_settings: Option<PySchemaRegistrySettings>,
subject: Option<String>,
designated_timestamp_policy: Option<String>,
external_diff_column_index: Option<usize>,
}
#[pymethods]
@ -4889,6 +4890,7 @@ impl DataFormat {
schema_registry_settings = None,
subject = None,
designated_timestamp_policy = None,
external_diff_column_index = None,
))]
#[allow(clippy::too_many_arguments)]
fn new(
@ -4907,6 +4909,7 @@ impl DataFormat {
schema_registry_settings: Option<PySchemaRegistrySettings>,
subject: Option<String>,
designated_timestamp_policy: Option<String>,
external_diff_column_index: Option<usize>,
) -> Self {
DataFormat {
format_type,
@ -4924,6 +4927,7 @@ impl DataFormat {
schema_registry_settings,
subject,
designated_timestamp_policy,
external_diff_column_index,
}
}
@ -6124,6 +6128,7 @@ impl DataFormat {
self.table_name()?,
key_field_names,
self.value_field_names(py),
self.external_diff_column_index,
);
match maybe_formatter {
Ok(formatter) => Ok(Box::new(formatter)),

View File

@ -20,6 +20,7 @@ fn test_psql_format_snapshot_commands() -> eyre::Result<()> {
"value_bool".to_string(),
"value_float".to_string(),
],
None,
)?;
let result = formatter.format(
@ -73,6 +74,7 @@ fn test_psql_primary_key_unspecified() -> eyre::Result<()> {
"value_bool".to_string(),
"value_float".to_string(),
],
None,
);
assert_matches!(formatter, Err(PsqlSnapshotFormatterError::UnknownKey(key)) if key == "key");
Ok(())
@ -89,6 +91,7 @@ fn test_psql_format_snapshot_composite() -> eyre::Result<()> {
"value_bool".to_string(),
"value_float".to_string(),
],
None,
);
let result = formatter?.format(
@ -112,3 +115,58 @@ fn test_psql_format_snapshot_composite() -> eyre::Result<()> {
Ok(())
}
#[test]
fn test_psql_format_external_diff() -> eyre::Result<()> {
let mut formatter = PsqlSnapshotFormatter::new(
"table_name".to_string(),
vec!["key".to_string()],
vec![
"key".to_string(),
"external_diff".to_string(),
"value_bool".to_string(),
"value_float".to_string(),
],
Some(1),
)?;
let result = formatter.format(
&Key::for_value(&Value::from("1")),
&[
Value::from("k"),
Value::from(1),
Value::Bool(true),
Value::from(1.23),
],
Timestamp(5),
1,
)?;
assert_eq!(result.payloads.len(), 1);
assert_document_raw_byte_contents(
&result.payloads[0],
b"INSERT INTO table_name (key,external_diff,value_bool,value_float,time,diff) VALUES ($1,$2,$3,$4,5,1) ON CONFLICT (key) DO UPDATE SET external_diff=$2,value_bool=$3,value_float=$4,time=5,diff=1 WHERE table_name.key=$1\n"
);
assert_eq!(result.values.len(), 4);
let result = formatter.format(
&Key::for_value(&Value::from("1")),
&[
Value::from("k"),
Value::from(-1),
Value::Bool(true),
Value::from(1.23),
],
Timestamp(5),
1, // even though diff=1, external_diff_column that is `Value::from(-1)` will be used
)?;
assert_eq!(result.payloads.len(), 1);
assert_document_raw_byte_contents(
&result.payloads[0],
b"DELETE FROM table_name WHERE key=$1\n",
);
assert_eq!(result.values.len(), 1);
Ok(())
}