Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions src/dve/core_engine/backends/base/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def read_to_py_iterator(
resource: URI,
entity_name: EntityName,
schema: type[BaseModel],
all_model_fields: Optional[set[str]] = None,
) -> Iterator[dict[str, Any]]:
"""Iterate through the contents of the resource, yielding dicts
representing each record.
Expand All @@ -107,6 +108,7 @@ def read_to_entity_type(
resource: URI,
entity_name: EntityName,
schema: type[BaseModel],
all_model_fields: Optional[set[str]] = None,
) -> EntityType:
"""Read to the specified entity type, if supported.

Expand All @@ -116,7 +118,12 @@ def read_to_entity_type(

"""
if entity_name == Iterator[dict[str, Any]]:
return self.read_to_py_iterator(resource, entity_name, schema) # type: ignore
return self.read_to_py_iterator(
resource,
entity_name,
schema, # type: ignore
all_model_fields
)

self.raise_if_not_sensible_file(resource, entity_name)

Expand All @@ -125,7 +132,13 @@ def read_to_entity_type(
except KeyError as err:
raise ReaderLacksEntityTypeSupport(entity_type=entity_type) from err

return reader_func(self, resource, entity_name, schema)
return reader_func(
self,
resource,
entity_name,
schema,
all_model_fields=all_model_fields # type: ignore
)

def add_record_index(self, entity: EntityType, **kwargs) -> EntityType:
"""Add a record index to the entity"""
Expand Down
96 changes: 49 additions & 47 deletions src/dve/core_engine/backends/implementations/duckdb/readers/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@
from duckdb import DuckDBPyConnection, DuckDBPyRelation, StarExpression, read_csv
from pydantic import BaseModel

from dve.core_engine.backends.base.reader import BaseFileReader, read_function
from dve.core_engine.backends.base.reader import read_function
from dve.core_engine.backends.exceptions import EmptyFileError, MessageBearingError
from dve.core_engine.backends.implementations.duckdb.duckdb_helpers import (
duckdb_record_index,
duckdb_write_parquet,
get_duckdb_type_from_annotation,
)
from dve.core_engine.backends.implementations.duckdb.types import SQLType
from dve.core_engine.backends.readers.utilities import check_csv_header_expected
from dve.core_engine.backends.readers.csv import CSVFileReader
from dve.core_engine.backends.utilities import get_polars_type_from_annotation, polars_record_index
from dve.core_engine.constants import RECORD_INDEX_COLUMN_NAME
from dve.core_engine.message import FeedbackMessage
Expand All @@ -27,7 +27,7 @@

@duckdb_record_index
@duckdb_write_parquet
class DuckDBCSVReader(BaseFileReader):
class DuckDBCSVReader(CSVFileReader):
"""A reader for CSV files including the ability to compare the passed model
to the file header, if it exists.

Expand All @@ -47,66 +47,57 @@ def __init__(
quotechar: str = '"',
connection: Optional[DuckDBPyConnection] = None,
field_check: bool = False,
field_check_error_code: Optional[str] = "ExpectedVsActualFieldMismatch",
field_check_error_message: Optional[str] = "The submitted header is missing fields",
field_check_error_code: str = "ExpectedVsActualFieldMismatch",
field_check_error_message: str = "The submitted header is missing fields",
null_empty_strings: bool = False,
**_,
):
self.header = header
self.delim = delim
self.quotechar = quotechar
self._connection = connection if connection else ddb.connect(":memory:")
self.field_check = field_check
self.field_check_error_code = field_check_error_code
self.field_check_error_message = field_check_error_message
self.null_empty_strings = null_empty_strings

super().__init__()

def perform_field_check(
self, resource: URI, entity_name: str, expected_schema: type[BaseModel]
):
"""Check that the header of the CSV aligns with the provided model"""
if not self.header:
raise ValueError("Cannot perform field check without a CSV header")

if missing := check_csv_header_expected(resource, expected_schema, self.delim):
raise MessageBearingError(
"The CSV header doesn't match what is expected",
messages=[
FeedbackMessage(
entity=entity_name,
record={"missing_fields": missing},
failure_type="submission",
error_location="Whole File",
reporting_field="missing_fields",
error_code=self.field_check_error_code,
error_message=f"{self.field_check_error_message}", # pylint: disable=line-too-long
)
],
)
super().__init__(
header=header,
delimiter=delim,
quote_char=quotechar,
field_check=field_check,
field_check_error_code=field_check_error_code,
field_check_error_message=field_check_error_message
)

def read_to_py_iterator(
self, resource: URI, entity_name: EntityName, schema: type[BaseModel]
self,
resource: URI,
entity_name: EntityName,
schema: type[BaseModel],
all_model_fields: Optional[set[str]] = None,
) -> Iterator[dict[str, Any]]:
"""Creates an iterable object of rows as dictionaries"""
yield from self.read_to_relation(resource, entity_name, schema).pl().iter_rows(named=True)
yield from self.read_to_relation(
resource,
entity_name,
schema,
all_model_fields,
).pl().iter_rows(named=True)

@read_function(DuckDBPyRelation)
def read_to_relation( # pylint: disable=unused-argument
self, resource: URI, entity_name: EntityName, schema: type[BaseModel]
self,
resource: URI,
entity_name: EntityName,
schema: type[BaseModel],
all_model_fields: Optional[set[str]] = None,
) -> DuckDBPyRelation:
"""Returns a relation object from the source csv"""
if get_content_length(resource) == 0:
raise EmptyFileError(f"File at {resource} is empty.")

if self.field_check:
self.perform_field_check(resource, entity_name, schema)
self.perform_field_check(resource, entity_name, schema, all_model_fields)

reader_options: dict[str, Any] = {
"header": self.header,
"delimiter": self.delim,
"quotechar": self.quotechar,
"delimiter": self.delimiter,
"quotechar": self.quote_char,
}

ddb_schema: dict[str, SQLType] = {
Expand Down Expand Up @@ -138,19 +129,23 @@ class PolarsToDuckDBCSVReader(DuckDBCSVReader):

@read_function(DuckDBPyRelation)
def read_to_relation( # pylint: disable=unused-argument
self, resource: URI, entity_name: EntityName, schema: type[BaseModel]
self,
resource: URI,
entity_name: EntityName,
schema: type[BaseModel],
all_model_fields: Optional[set[str]] = None,
) -> DuckDBPyRelation:
"""Returns a relation object from the source csv"""
if get_content_length(resource) == 0:
raise EmptyFileError(f"File at {resource} is empty.")

if self.field_check:
self.perform_field_check(resource, entity_name, schema)
self.perform_field_check(resource, entity_name, schema, all_model_fields)

reader_options: dict[str, Any] = {
"has_header": self.header,
"separator": self.delim,
"quote_char": self.quotechar,
"separator": self.delimiter,
"quote_char": self.quote_char,
}

polars_types = {
Expand Down Expand Up @@ -216,10 +211,17 @@ def __init__(

@read_function(DuckDBPyRelation)
def read_to_relation( # pylint: disable=unused-argument
self, resource: URI, entity_name: EntityName, schema: type[BaseModel]
self,
resource: URI,
entity_name: EntityName,
schema: type[BaseModel],
all_model_fields: Optional[set[str]] = None,
) -> DuckDBPyRelation:
entity: DuckDBPyRelation = super().read_to_relation(
resource=resource, entity_name=entity_name, schema=schema
resource=resource,
entity_name=entity_name,
schema=schema,
all_model_fields=all_model_fields
)
entity = entity.select(StarExpression(exclude=[RECORD_INDEX_COLUMN_NAME])).distinct()
no_records = entity.shape[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,22 @@ def __init__(
super().__init__()

def read_to_py_iterator(
self, resource: URI, entity_name: EntityName, schema: type[BaseModel]
self,
resource: URI,
entity_name: EntityName,
schema: type[BaseModel],
all_model_fields: Optional[set[str]] = None,
) -> Iterator[dict[str, Any]]:
"""Creates an iterable object of rows as dictionaries"""
return self.read_to_relation(resource, entity_name, schema).pl().iter_rows(named=True)

@read_function(DuckDBPyRelation)
def read_to_relation( # pylint: disable=unused-argument
self, resource: URI, entity_name: EntityName, schema: type[BaseModel]
self,
resource: URI,
entity_name: EntityName,
schema: type[BaseModel],
**_,
) -> DuckDBPyRelation:
"""Returns a relation object from the source json"""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,13 @@ def __init__(self, *, connection: Optional[DuckDBPyConnection] = None, **kwargs)
super().__init__(**kwargs)

@read_function(DuckDBPyRelation)
def read_to_relation(self, resource: URI, entity_name: str, schema: type[BaseModel]):
def read_to_relation(
self,
resource: URI,
entity_name: str,
schema: type[BaseModel],
**_,
):
"""Returns a relation object from the source xml"""
if self.xsd_location:
msg = self._run_xmllint(file_uri=resource)
Expand Down
37 changes: 27 additions & 10 deletions src/dve/core_engine/backends/implementations/spark/readers/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import StructType

from dve.core_engine.backends.base.reader import BaseFileReader, read_function
from dve.core_engine.backends.base.reader import read_function
from dve.core_engine.backends.readers.csv import CSVFileReader
from dve.core_engine.backends.exceptions import EmptyFileError
from dve.core_engine.backends.implementations.spark.spark_helpers import (
get_type_from_annotation,
Expand All @@ -21,9 +22,10 @@

@spark_record_index
@spark_write_parquet
class SparkCSVReader(BaseFileReader):
class SparkCSVReader(CSVFileReader):
"""A Spark reader for CSV files."""

# pylint: disable=R0902
def __init__(
self,
*,
Expand All @@ -35,24 +37,35 @@ def __init__(
encoding: str = "utf-8-sig",
null_empty_strings: bool = False,
spark_session: Optional[SparkSession] = None,
field_check: bool = False,
field_check_error_code: str = "ExpectedVsActualFieldMismatch",
field_check_error_message: str = "The submitted header is missing fields",
**_,
) -> None:

self.delimiter = delimiter
self.escape_char = escape_char
self.encoding = encoding
self.quote_char = quote_char
self.header = header
self.multi_line = multi_line
self.null_empty_strings = null_empty_strings
self.spark_session = spark_session if spark_session else SparkSession.builder.getOrCreate() # type: ignore # pylint: disable=C0301

super().__init__()
super().__init__(
delimiter=delimiter,
escape_char=escape_char,
encoding=encoding,
quote_char=quote_char,
header=header,
field_check=field_check,
field_check_error_code=field_check_error_code,
field_check_error_message=field_check_error_message,
)

def read_to_py_iterator(
self, resource: URI, entity_name: EntityName, schema: type[BaseModel]
self,
resource: URI,
entity_name: EntityName,
schema: type[BaseModel],
all_model_fields: Optional[set[str]] = None,
) -> Iterator[dict[URI, Any]]:
df = self.read_to_dataframe(resource, entity_name, schema)
df = self.read_to_dataframe(resource, entity_name, schema, all_model_fields)
yield from (record.asDict(True) for record in df.toLocalIterator())

@read_function(DataFrame)
Expand All @@ -61,11 +74,15 @@ def read_to_dataframe(
resource: URI,
entity_name: EntityName, # pylint: disable=unused-argument
schema: type[BaseModel],
all_model_fields: Optional[set[str]] = None,
) -> DataFrame:
"""Read a CSV file directly to a Spark DataFrame."""
if get_content_length(resource) == 0:
raise EmptyFileError(f"File at {resource} is empty.")

if self.field_check:
self.perform_field_check(resource, entity_name, schema, all_model_fields)

spark_schema: StructType = get_type_from_annotation(schema)
kwargs = {
"sep": self.delimiter,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ def __init__(
super().__init__()

def read_to_py_iterator(
self, resource: URI, entity_name: EntityName, schema: type[BaseModel]
self,
resource: URI,
entity_name: EntityName,
schema: type[BaseModel],
all_model_fields: Optional[set[str]] = None,
) -> Iterator[dict[URI, Any]]:
df = self.read_to_dataframe(resource, entity_name, schema)
yield from (record.asDict(True) for record in df.toLocalIterator())
Expand All @@ -50,6 +54,7 @@ def read_to_dataframe(
resource: URI,
entity_name: EntityName, # pylint: disable=unused-argument
schema: type[BaseModel],
**_,
) -> DataFrame:
"""Read a JSON file directly to a Spark DataFrame."""
if get_content_length(resource) == 0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,11 @@ def __init__(
self.namespace = namespace

def read_to_py_iterator(
self, resource: URI, entity_name: EntityName, schema: type[BaseModel]
self,
resource: URI,
entity_name: EntityName,
schema: type[BaseModel],
all_model_fields: Optional[set[str]] = None,
) -> Iterator[dict[URI, Any]]:
df = self.read_to_dataframe(resource, entity_name, schema)
yield from (record.asDict(True) for record in df.toLocalIterator())
Expand All @@ -115,6 +119,7 @@ def read_to_dataframe(
resource: URI,
entity_name: EntityName, # pylint: disable=unused-argument
schema: type[BaseModel],
**_,
) -> DataFrame:
"""Read an XML file directly to a Spark DataFrame using the Databricks
XML reader package.
Expand Down
Loading
Loading