support external diff column in postgres snapshot connector (#8979)
GitOrigin-RevId: e1c87bea13a6a522878122b40c3e460d3295b573
This commit is contained in:
parent
4f465a7597
commit
9e9d8fc4e8
|
@ -409,3 +409,78 @@ def test_psql_json_datetimes(postgres):
|
|||
assert result["a"] == expected
|
||||
assert result["b"] == expected
|
||||
assert result["c"] == expected
|
||||
|
||||
|
||||
def test_psql_external_diff_column(tmp_path, postgres):
|
||||
class InputSchema(pw.Schema):
|
||||
name: str = pw.column_definition(primary_key=True)
|
||||
count: int
|
||||
price: float
|
||||
available: bool
|
||||
external_diff: int
|
||||
|
||||
input_path = tmp_path / "input.txt"
|
||||
output_table = postgres.create_table(InputSchema, used_for_output=True)
|
||||
|
||||
def _run(test_items: list[dict]) -> None:
|
||||
G.clear()
|
||||
with open(input_path, "w") as f:
|
||||
for test_item in test_items:
|
||||
f.write(json.dumps(test_item) + "\n")
|
||||
table = pw.io.jsonlines.read(input_path, schema=InputSchema, mode="static")
|
||||
pw.io.postgres.write_snapshot(
|
||||
table,
|
||||
POSTGRES_SETTINGS,
|
||||
output_table,
|
||||
["name"],
|
||||
_external_diff_column=table.external_diff,
|
||||
)
|
||||
run()
|
||||
|
||||
test_items = [
|
||||
{
|
||||
"name": "Milk",
|
||||
"count": 500,
|
||||
"price": 1.5,
|
||||
"available": False,
|
||||
"external_diff": 1,
|
||||
},
|
||||
{
|
||||
"name": "Water",
|
||||
"count": 600,
|
||||
"price": 0.5,
|
||||
"available": True,
|
||||
"external_diff": 1,
|
||||
},
|
||||
]
|
||||
_run(test_items)
|
||||
|
||||
rows = postgres.get_table_contents(output_table, InputSchema.column_names())
|
||||
rows.sort(key=lambda item: (item["name"], item["available"]))
|
||||
assert rows == test_items
|
||||
|
||||
# Also test that the junk data in the additional columns would not break deletion
|
||||
new_test_items = [
|
||||
{
|
||||
"name": "Milk",
|
||||
"count": -1,
|
||||
"price": -1.0,
|
||||
"available": True,
|
||||
"external_diff": -1,
|
||||
}
|
||||
]
|
||||
_run(new_test_items)
|
||||
|
||||
rows = postgres.get_table_contents(
|
||||
output_table, InputSchema.column_names(), ("name", "available")
|
||||
)
|
||||
expected_rows = [
|
||||
{
|
||||
"name": "Water",
|
||||
"count": 600,
|
||||
"price": 0.5,
|
||||
"available": True,
|
||||
"external_diff": 1,
|
||||
},
|
||||
]
|
||||
assert rows == expected_rows
|
||||
|
|
|
@ -488,3 +488,14 @@ def _prepare_s3_connection_engine_settings(
|
|||
if aws_s3_settings is None:
|
||||
return None
|
||||
return aws_s3_settings.settings
|
||||
|
||||
|
||||
def get_column_index(table: Table, column: ColumnReference | None) -> int | None:
|
||||
if column is None:
|
||||
return None
|
||||
if column._table != table:
|
||||
raise ValueError(f"The column {column} doesn't belong to the target table")
|
||||
for index, table_column in enumerate(table._columns):
|
||||
if table_column == column.name:
|
||||
return index
|
||||
raise RuntimeError(f"The column {column} is not found in the table {table}")
|
||||
|
|
|
@ -4,12 +4,13 @@ from __future__ import annotations
|
|||
|
||||
from typing import Iterable, Literal
|
||||
|
||||
from pathway.internals import api, datasink
|
||||
from pathway.internals import api, datasink, dtype
|
||||
from pathway.internals._io_helpers import _format_output_value_fields
|
||||
from pathway.internals.expression import ColumnReference
|
||||
from pathway.internals.runtime_type_check import check_arg_types
|
||||
from pathway.internals.table import Table
|
||||
from pathway.internals.trace import trace_user_frame
|
||||
from pathway.io._utils import get_column_index
|
||||
|
||||
|
||||
def _connection_string_from_settings(settings: dict):
|
||||
|
@ -159,6 +160,7 @@ def write_snapshot(
|
|||
init_mode: Literal["default", "create_if_not_exists", "replace"] = "default",
|
||||
name: str | None = None,
|
||||
sort_by: Iterable[ColumnReference] | None = None,
|
||||
_external_diff_column: ColumnReference | None = None,
|
||||
) -> None:
|
||||
"""Maintains a snapshot of a table within a Postgres table.
|
||||
|
||||
|
@ -232,11 +234,19 @@ def write_snapshot(
|
|||
table_name=table_name,
|
||||
sql_writer_init_mode=_init_mode_from_str(init_mode),
|
||||
)
|
||||
|
||||
if (
|
||||
_external_diff_column is not None
|
||||
and _external_diff_column._column.dtype != dtype.INT
|
||||
):
|
||||
raise ValueError("_external_diff_column can only have an integer type")
|
||||
external_diff_column_index = get_column_index(table, _external_diff_column)
|
||||
data_format = api.DataFormat(
|
||||
format_type="sql_snapshot",
|
||||
key_field_names=primary_key,
|
||||
value_fields=_format_output_value_fields(table),
|
||||
table_name=table_name,
|
||||
external_diff_column_index=external_diff_column_index,
|
||||
)
|
||||
|
||||
table.to(
|
||||
|
|
|
@ -9,6 +9,7 @@ from pathway.internals.expression import ColumnReference
|
|||
from pathway.internals.runtime_type_check import check_arg_types
|
||||
from pathway.internals.table import Table
|
||||
from pathway.internals.trace import trace_user_frame
|
||||
from pathway.io._utils import get_column_index
|
||||
|
||||
|
||||
@check_arg_types
|
||||
|
@ -142,13 +143,7 @@ def write(
|
|||
f"designated_timestamp is passed, but designated_timestamp_policy is {designated_timestamp_policy}"
|
||||
)
|
||||
designated_timestamp_policy = "use_column"
|
||||
if designated_timestamp._table != table:
|
||||
raise ValueError(
|
||||
f"The column {designated_timestamp} doesn't belong to the target table"
|
||||
)
|
||||
for index, column in enumerate(table._columns):
|
||||
if column == designated_timestamp.name:
|
||||
designated_timestamp_index = index
|
||||
designated_timestamp_index = get_column_index(table, designated_timestamp)
|
||||
elif designated_timestamp_policy is None:
|
||||
designated_timestamp_policy = "use_now"
|
||||
|
||||
|
|
|
@ -462,6 +462,9 @@ pub enum FormatterError {
|
|||
|
||||
#[error(transparent)]
|
||||
SchemaRepository(#[from] SchemaRepositoryError),
|
||||
|
||||
#[error("incorrect external diff value: {0}")]
|
||||
IncorrectDiffColumnValue(Value),
|
||||
}
|
||||
|
||||
pub trait Formatter: Send {
|
||||
|
@ -1877,6 +1880,7 @@ pub struct PsqlSnapshotFormatter {
|
|||
|
||||
key_field_positions: Vec<usize>,
|
||||
value_field_positions: Vec<usize>,
|
||||
external_diff_column_index: Option<usize>,
|
||||
}
|
||||
|
||||
impl PsqlSnapshotFormatter {
|
||||
|
@ -1884,8 +1888,9 @@ impl PsqlSnapshotFormatter {
|
|||
table_name: String,
|
||||
mut key_field_names: Vec<String>,
|
||||
mut value_field_names: Vec<String>,
|
||||
external_diff_column_index: Option<usize>,
|
||||
) -> Result<PsqlSnapshotFormatter, PsqlSnapshotFormatterError> {
|
||||
let mut field_positions = HashMap::<String, usize>::new();
|
||||
let mut field_positions = HashMap::<String, usize>::with_capacity(value_field_names.len());
|
||||
for (index, field_name) in value_field_names.iter_mut().enumerate() {
|
||||
if field_positions.contains_key(field_name) {
|
||||
return Err(PsqlSnapshotFormatterError::RepeatedValueField(take(
|
||||
|
@ -1914,6 +1919,7 @@ impl PsqlSnapshotFormatter {
|
|||
|
||||
key_field_positions,
|
||||
value_field_positions,
|
||||
external_diff_column_index,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -1932,7 +1938,20 @@ impl Formatter for PsqlSnapshotFormatter {
|
|||
|
||||
let mut result = Vec::new();
|
||||
|
||||
if diff == 1 {
|
||||
let effective_diff: isize =
|
||||
if let Some(external_diff_column_index) = self.external_diff_column_index {
|
||||
let value = &values[external_diff_column_index];
|
||||
match value {
|
||||
Value::Int(x) if *x == -1 || *x == 1 => (*x)
|
||||
.try_into()
|
||||
.expect("the values from {-1, 1} must convert into isize"),
|
||||
_ => return Err(FormatterError::IncorrectDiffColumnValue(value.clone())),
|
||||
}
|
||||
} else {
|
||||
diff
|
||||
};
|
||||
|
||||
if effective_diff == 1 {
|
||||
let update_pairs = self
|
||||
.value_field_positions
|
||||
.iter()
|
||||
|
|
|
@ -4686,6 +4686,7 @@ pub struct DataFormat {
|
|||
schema_registry_settings: Option<PySchemaRegistrySettings>,
|
||||
subject: Option<String>,
|
||||
designated_timestamp_policy: Option<String>,
|
||||
external_diff_column_index: Option<usize>,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
|
@ -4889,6 +4890,7 @@ impl DataFormat {
|
|||
schema_registry_settings = None,
|
||||
subject = None,
|
||||
designated_timestamp_policy = None,
|
||||
external_diff_column_index = None,
|
||||
))]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
|
@ -4907,6 +4909,7 @@ impl DataFormat {
|
|||
schema_registry_settings: Option<PySchemaRegistrySettings>,
|
||||
subject: Option<String>,
|
||||
designated_timestamp_policy: Option<String>,
|
||||
external_diff_column_index: Option<usize>,
|
||||
) -> Self {
|
||||
DataFormat {
|
||||
format_type,
|
||||
|
@ -4924,6 +4927,7 @@ impl DataFormat {
|
|||
schema_registry_settings,
|
||||
subject,
|
||||
designated_timestamp_policy,
|
||||
external_diff_column_index,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -6124,6 +6128,7 @@ impl DataFormat {
|
|||
self.table_name()?,
|
||||
key_field_names,
|
||||
self.value_field_names(py),
|
||||
self.external_diff_column_index,
|
||||
);
|
||||
match maybe_formatter {
|
||||
Ok(formatter) => Ok(Box::new(formatter)),
|
||||
|
|
|
@ -20,6 +20,7 @@ fn test_psql_format_snapshot_commands() -> eyre::Result<()> {
|
|||
"value_bool".to_string(),
|
||||
"value_float".to_string(),
|
||||
],
|
||||
None,
|
||||
)?;
|
||||
|
||||
let result = formatter.format(
|
||||
|
@ -73,6 +74,7 @@ fn test_psql_primary_key_unspecified() -> eyre::Result<()> {
|
|||
"value_bool".to_string(),
|
||||
"value_float".to_string(),
|
||||
],
|
||||
None,
|
||||
);
|
||||
assert_matches!(formatter, Err(PsqlSnapshotFormatterError::UnknownKey(key)) if key == "key");
|
||||
Ok(())
|
||||
|
@ -89,6 +91,7 @@ fn test_psql_format_snapshot_composite() -> eyre::Result<()> {
|
|||
"value_bool".to_string(),
|
||||
"value_float".to_string(),
|
||||
],
|
||||
None,
|
||||
);
|
||||
|
||||
let result = formatter?.format(
|
||||
|
@ -112,3 +115,58 @@ fn test_psql_format_snapshot_composite() -> eyre::Result<()> {
|
|||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_psql_format_external_diff() -> eyre::Result<()> {
|
||||
let mut formatter = PsqlSnapshotFormatter::new(
|
||||
"table_name".to_string(),
|
||||
vec!["key".to_string()],
|
||||
vec![
|
||||
"key".to_string(),
|
||||
"external_diff".to_string(),
|
||||
"value_bool".to_string(),
|
||||
"value_float".to_string(),
|
||||
],
|
||||
Some(1),
|
||||
)?;
|
||||
|
||||
let result = formatter.format(
|
||||
&Key::for_value(&Value::from("1")),
|
||||
&[
|
||||
Value::from("k"),
|
||||
Value::from(1),
|
||||
Value::Bool(true),
|
||||
Value::from(1.23),
|
||||
],
|
||||
Timestamp(5),
|
||||
1,
|
||||
)?;
|
||||
|
||||
assert_eq!(result.payloads.len(), 1);
|
||||
assert_document_raw_byte_contents(
|
||||
&result.payloads[0],
|
||||
b"INSERT INTO table_name (key,external_diff,value_bool,value_float,time,diff) VALUES ($1,$2,$3,$4,5,1) ON CONFLICT (key) DO UPDATE SET external_diff=$2,value_bool=$3,value_float=$4,time=5,diff=1 WHERE table_name.key=$1\n"
|
||||
);
|
||||
assert_eq!(result.values.len(), 4);
|
||||
|
||||
let result = formatter.format(
|
||||
&Key::for_value(&Value::from("1")),
|
||||
&[
|
||||
Value::from("k"),
|
||||
Value::from(-1),
|
||||
Value::Bool(true),
|
||||
Value::from(1.23),
|
||||
],
|
||||
Timestamp(5),
|
||||
1, // even though diff=1, external_diff_column that is `Value::from(-1)` will be used
|
||||
)?;
|
||||
|
||||
assert_eq!(result.payloads.len(), 1);
|
||||
assert_document_raw_byte_contents(
|
||||
&result.payloads[0],
|
||||
b"DELETE FROM table_name WHERE key=$1\n",
|
||||
);
|
||||
assert_eq!(result.values.len(), 1);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue