Skip to content

Polars Generator

Generate Polars DataFrame validators from Flycatcher schemas.

create_polars_validator(schema_cls)

Create a Polars validator from a Schema class.

Parameters:

Name Type Description Default
schema_cls type

A subclass of Schema.

required

Returns:

Type Description
PolarsValidator

An instance of PolarsValidator for the given schema.

Source code in src/flycatcher/generators/polars.py
def create_polars_validator(schema_cls: "type[Schema]") -> PolarsValidator:
    """
    Create a Polars validator from a Schema class.

    Parameters
    ----------
    schema_cls : type
        A subclass of Schema.

    Returns
    -------
    PolarsValidator
        An instance of PolarsValidator for the given schema.
    """
    return PolarsValidator(schema_cls)

PolarsValidator(schema_cls)

A validator for Polars DataFrames based on schema definition.

This class validates Polars DataFrames against a Flycatcher schema, checking types, constraints, and cross-field validators.

Parameters:

Name Type Description Default
schema_cls type[Schema]

A Flycatcher Schema class to validate against.

required

Examples:

>>> from flycatcher import Field, Schema
>>> import polars as pl
>>> class UserSchema(Schema):
...     id: int = Field(primary_key=True)
...     name: str = Field(min_length=1)
>>> validator = UserSchema.to_polars_validator()
>>> df = pl.DataFrame({"id": [1, 2], "name": ["Alice", "Bob"]})
>>> validated_df = validator.validate(df, strict=True)
Source code in src/flycatcher/generators/polars.py
def __init__(self, schema_cls: "type[Schema]") -> None:
    self.schema_cls = schema_cls
    self.fields = schema_cls.fields()
    self._polars_schema = self._build_polars_schema()
    self._constraints = self._build_constraints()

schema property

Return the Polars schema dict.

validate(df, strict=True, show_violations=False, fill_nulls=False)

Validate and coerce a DataFrame to match the schema.

Parameters:

Name Type Description Default
df DataFrame

Input Polars DataFrame.

required
strict bool

If True, raise on validation errors. If False, filter invalid rows.

True
show_violations bool

If True, show violations in the console.

False
fill_nulls bool

If True, replace null values with field defaults (if specified). Note: This is a transformation step. Defaults only apply to missing columns by default. Enable this to also fill existing null values.

False

Returns:

Type Description
DataFrame

Validated DataFrame with correct types. If fill_nulls=True, null values will be replaced with defaults where applicable.

Raises:

Type Description
ValueError

If validation fails and strict=True.

Notes

Behavior of defaults:

  • Missing columns with defaults are always added to the DataFrame
  • Existing null values are filled with defaults only if fill_nulls=True
  • If a field is nullable without a default, nulls are preserved
Source code in src/flycatcher/generators/polars.py
def validate(
    self,
    df: pl.DataFrame,
    strict: bool = True,
    show_violations: bool = False,
    fill_nulls: bool = False,
) -> pl.DataFrame:
    """
    Validate and coerce a DataFrame to match the schema.

    Parameters
    ----------
    df : pl.DataFrame
        Input Polars DataFrame.
    strict : bool, default True
        If True, raise on validation errors. If False, filter invalid rows.
    show_violations : bool, default False
        If True, show violations in the console.
    fill_nulls : bool, default False
        If True, replace null values with field defaults (if specified).
        Note: This is a transformation step. Defaults only apply to missing
        columns by default. Enable this to also fill existing null values.

    Returns
    -------
    pl.DataFrame
        Validated DataFrame with correct types. If fill_nulls=True, null
        values will be replaced with defaults where applicable.

    Raises
    ------
    ValueError
        If validation fails and strict=True.

    Notes
    -----
    Behavior of defaults:

    - Missing columns with defaults are always added to the DataFrame
    - Existing null values are filled with defaults only if fill_nulls=True
    - If a field is nullable without a default, nulls are preserved
    """
    # Check for missing required columns (no default value = required)
    required_cols = {
        name for name, field in self.fields.items() if field.default is _MISSING
    }
    missing = required_cols - set(df.columns)
    if missing:
        raise ValueError(f"Missing required columns: {missing}")

    # Add missing columns with default values
    missing_with_defaults = []
    for field_name, field in self.fields.items():
        if field_name not in df.columns and field.default is not _MISSING:
            missing_with_defaults.append((field_name, field.default, field))

    if missing_with_defaults:
        for field_name, default_value, field in missing_with_defaults:
            dtype = field.get_polars_dtype()
            df = df.with_columns(
                pl.lit(default_value).cast(dtype).alias(field_name)
            )
            logger.info(
                f"Added column '{field_name}' with default value: {default_value}"
            )

    # Cast to correct types and ensure column order matches schema
    cast_exprs = []
    for col_name, dtype in self._polars_schema.items():
        if col_name in df.columns:
            cast_exprs.append(pl.col(col_name).cast(dtype, strict=False))

    if cast_exprs:
        df = df.select(cast_exprs)

    # Handle nulls based on configuration
    for field_name, field in self.fields.items():
        if field_name not in df.columns:
            continue

        null_count = df[field_name].null_count()
        if null_count == 0:
            continue  # No nulls, nothing to do

        # Option 1: Fill nulls with defaults (if enabled and default exists)
        if fill_nulls and field.default is not _MISSING:
            dtype = field.get_polars_dtype()
            df = df.with_columns(
                pl.col(field_name)
                .fill_null(pl.lit(field.default).cast(dtype))
                .alias(field_name)
            )
            logger.info(
                f"Filled {null_count} null values in '{field_name}' "
                f"with default: {field.default}"
            )
            continue

        # Option 2: Validate nulls against nullable constraint
        if not field.nullable:
            if strict:
                raise ValueError(
                    f"Column '{field_name}' has {null_count} null values "
                    f"but is not nullable"
                )
            else:
                # Filter out null rows
                df = df.filter(pl.col(field_name).is_not_null())
        # If nullable=True, nulls are allowed and preserved

    # Apply custom constraints
    violations = []
    for constraint_expr, error_msg in self._constraints:
        # Get rows that violate the constraint
        try:
            invalid_mask = ~constraint_expr
            violation_count = df.filter(invalid_mask).height

            if violation_count > 0:
                if strict:
                    # Show sample of violations
                    sample_violations = df.filter(invalid_mask).head(5)
                    raise ValueError(
                        f"Constraint violation: {error_msg}\n"
                        f"Found {violation_count} violations.\n"
                        f"Sample violations:\n{sample_violations}"
                    )
                else:
                    # Filter out invalid rows
                    violations.append(
                        {
                            "constraint": error_msg,
                            "count": violation_count,
                            "rows": df.filter(invalid_mask).head(10),
                        }
                    )
                    df = df.filter(constraint_expr)
        except ValueError as e:
            # Check if this is our intentional ValueError for constraint violations
            # (has "Constraint violation" in the message) vs an evaluation error
            if "Constraint violation" in str(e) and strict:
                # This is our intentional ValueError - re-raise it
                raise
            # Otherwise, it's an evaluation error - log and continue
            logger.warning(
                f"Could not evaluate constraint '{error_msg}': {e}",
                exc_info=True,
            )
        except Exception as e:
            # Handle cases where constraint can't be evaluated
            logger.warning(
                f"Could not evaluate constraint '{error_msg}': {e}",
                exc_info=True,
            )

    if show_violations:
        for violation in violations:
            logger.warning(f"Constraint violation: {violation['constraint']}")
            logger.warning(f"Count: {violation['count']}")
            logger.warning(f"Rows: {violation['rows']}")
            logger.warning("-" * 80)
    return df

describe_constraints()

Return human-readable list of constraints.

Source code in src/flycatcher/generators/polars.py
def describe_constraints(self) -> List[str]:
    """Return human-readable list of constraints."""
    return [msg for _, msg in self._constraints]

Usage

The create_polars_validator function is typically called via the Schema.to_polars_validator() method:

from flycatcher import Schema, Field
import polars as pl

class UserSchema(Schema):
    id: int = Field(primary_key=True)
    name: str = Field(min_length=1, max_length=100)

# Generate Polars validator
validator = UserSchema.to_polars_validator()

# Validate a DataFrame
df = pl.DataFrame({
    "id": [1, 2, 3],
    "name": ["Alice", "Bob", "Charlie"]
})

validated_df = validator.validate(df, strict=True)

Validation Modes

The PolarsValidator.validate() method supports different validation modes:

  • Strict mode (strict=True): Raises exceptions on validation errors
  • Non-strict mode (strict=False): Filters out invalid rows
  • Show violations (show_violations=True): Prints violation details to console
  • Fill nulls (fill_nulls=True): Replaces null values with field defaults
# Non-strict validation (filter invalid rows)
valid_df = validator.validate(df, strict=False)

# Show violations in console
validator.validate(df, strict=True, show_violations=True)

# Fill nulls with defaults
filled_df = validator.validate(df, fill_nulls=True)