Page Classification


# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "data-designer>=0.5.6",
# ]
# ///
"""Long-Document Understanding Page Classification Recipe

Classify document page images by their visual element types and reasoning
complexity using a vision-language model (VLM). For each seed record, the
pipeline produces a structured `page_classification` column containing:

  - `contains_reasoning_content` – whether the page has visual elements
    suitable for reasoning QA
  - `primary_categories` – ordered list of visual element categories
    (QUANTITATIVE, TABULAR, LOGIC_DIAGRAMS, HIERARCHICAL, etc.)
  - `subcategories` – specific element types (BAR_CHART, FLOWCHART, …)
  - `reasoning_complexity_score` – 1-10 cognitive demand rating
  - `justification` – brief explanation of the classification

Prerequisites:
    - A seed parquet file containing a `png_images_base64` column with a JSON
      array of base64-encoded PNG images (one element per page; with a
      per-page seed, each record's array has exactly one element).
    - A vLLM-compatible deployment of the VLM
      (default: Qwen/Qwen3-VL-30B-A3B-Instruct).
      Recommended vLLM launch flags:
        --tensor-parallel-size 2
        --max-model-len 128000
        --gpu-memory-utilization 0.95
        --trust-remote-code

      Example launch command for 2× H100 GPUs:
        docker run --gpus all \
            -p 8000:8000 \
            vllm/vllm-openai:latest \
            --model Qwen/Qwen3-VL-30B-A3B-Instruct \
            --tensor-parallel-size 2 \
            --max-model-len 128000 \
            --gpu-memory-utilization 0.95 \
            --trust-remote-code
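
      Sanity check: the server exposes the standard OpenAI-compatible
      endpoints, so you can verify it is up with, e.g.:
        curl http://localhost:8000/v1/models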

Run:
    # Basic usage (classifies 5 pages by default)
    uv run 04-page-classification-sdg.py --vllm-endpoint http://localhost:8000/v1 --seed-path seed_data/seed_per_page.parquet

    # Custom record count
    uv run 04-page-classification-sdg.py --vllm-endpoint http://localhost:8000/v1 --seed-path seed_data/seed_per_page.parquet --num-records 100

    # Show the help message and available options
    uv run 04-page-classification-sdg.py --help
"""

from enum import Enum
from pathlib import Path

from pydantic import BaseModel, Field

import data_designer.config as dd
from data_designer.interface import DataDesigner, DatasetCreationResults

DEFAULT_VLM_MODEL = "Qwen/Qwen3-VL-30B-A3B-Instruct"
VLLM_PROVIDER_NAME = "vllm"
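

# =============================================================================
# Optional: building a compatible seed parquet (illustrative sketch)
# =============================================================================
#
# The helper below is NOT part of the pipeline. It is a minimal sketch of how
# a seed file with the expected `png_images_base64` column could be produced
# from a directory of per-page PNGs. It assumes pandas and a parquet engine
# (e.g. pyarrow) are installed; neither is declared in the script
# dependencies above.


def _build_example_seed(png_dir: str | Path, out_path: str | Path) -> None:
    """Write a per-page seed parquet: one row per PNG, one-element JSON array each."""
    import base64
    import json

    import pandas as pd  # assumed available; not a declared dependency

    rows = []
    for png in sorted(Path(png_dir).glob("*.png")):
        b64 = base64.b64encode(png.read_bytes()).decode("ascii")
        # One-element JSON array per record, matching the per-page seed
        # format described in the module docstring.
        rows.append({"png_images_base64": json.dumps([b64])})
    pd.DataFrame(rows).to_parquet(out_path, index=False)
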

# =============================================================================
# Structured output schema
# =============================================================================


class VisualElementCategory(str, Enum):
    QUANTITATIVE = "QUANTITATIVE"
    LOGIC_DIAGRAMS = "LOGIC_DIAGRAMS"
    HIERARCHICAL = "HIERARCHICAL"
    SPATIAL_RELATIONAL = "SPATIAL_RELATIONAL"
    SCHEMATIC = "SCHEMATIC"
    TABULAR = "TABULAR"
    INFOGRAPHIC = "INFOGRAPHIC"
    NONE = "NONE"


class VisualElementSubcategory(str, Enum):
    # QUANTITATIVE
    BAR_CHART = "BAR_CHART"
    LINE_GRAPH = "LINE_GRAPH"
    SCATTER_PLOT = "SCATTER_PLOT"
    PIE_CHART = "PIE_CHART"
    AREA_GRAPH = "AREA_GRAPH"
    HISTOGRAM = "HISTOGRAM"
    BOX_PLOT = "BOX_PLOT"
    HEATMAP = "HEATMAP"
    BUBBLE_CHART = "BUBBLE_CHART"
    # LOGIC_DIAGRAMS
    FLOWCHART = "FLOWCHART"
    DECISION_TREE = "DECISION_TREE"
    PROCESS_MAP = "PROCESS_MAP"
    ALGORITHM_DIAGRAM = "ALGORITHM_DIAGRAM"
    STATE_DIAGRAM = "STATE_DIAGRAM"
    SEQUENCE_DIAGRAM = "SEQUENCE_DIAGRAM"
    # HIERARCHICAL
    ORG_CHART = "ORG_CHART"
    MIND_MAP = "MIND_MAP"
    TREE_STRUCTURE = "TREE_STRUCTURE"
    TAXONOMY = "TAXONOMY"
    DENDROGRAM = "DENDROGRAM"
    # SPATIAL_RELATIONAL
    FLOOR_PLAN = "FLOOR_PLAN"
    BLUEPRINT = "BLUEPRINT"
    CHOROPLETH_MAP = "CHOROPLETH_MAP"
    POINT_MAP = "POINT_MAP"
    TOPOGRAPHIC_MAP = "TOPOGRAPHIC_MAP"
    NETWORK_DIAGRAM = "NETWORK_DIAGRAM"
    # SCHEMATIC
    CIRCUIT_DIAGRAM = "CIRCUIT_DIAGRAM"
    MECHANICAL_DIAGRAM = "MECHANICAL_DIAGRAM"
    ANATOMICAL_DIAGRAM = "ANATOMICAL_DIAGRAM"
    WIRING_DIAGRAM = "WIRING_DIAGRAM"
    PLUMBING_DIAGRAM = "PLUMBING_DIAGRAM"
    # TABULAR
    SIMPLE_TABLE = "SIMPLE_TABLE"
    NESTED_TABLE = "NESTED_TABLE"
    PIVOT_TABLE = "PIVOT_TABLE"
    COMPARISON_TABLE = "COMPARISON_TABLE"
    FINANCIAL_TABLE = "FINANCIAL_TABLE"
    # INFOGRAPHIC
    TIMELINE = "TIMELINE"
    STATISTICAL_INFOGRAPHIC = "STATISTICAL_INFOGRAPHIC"
    PROCESS_INFOGRAPHIC = "PROCESS_INFOGRAPHIC"
    COMPARISON_INFOGRAPHIC = "COMPARISON_INFOGRAPHIC"
    # NONE
    DECORATIVE_IMAGE = "DECORATIVE_IMAGE"
    PHOTOGRAPH = "PHOTOGRAPH"
    PLAIN_TEXT = "PLAIN_TEXT"
    GENERIC_ICON = "GENERIC_ICON"
    OTHER = "OTHER"


class PageClassification(BaseModel):
    """Classification result for a document page's reasoning potential."""

    contains_reasoning_content: bool = Field(
        ...,
        description=(
            "Whether the page contains visual elements suitable for reasoning QA pairs. "
            "Must be False if primary_categories contains NONE. "
            "Must be True if primary_categories does NOT contain NONE."
        ),
    )
    primary_categories: list[VisualElementCategory] = Field(
        ...,
        description=(
            "List of visual element categories found in the page, ordered by prominence. "
            "IMPORTANT: If NONE is present, it must be the ONLY category in this list."
        ),
    )
    subcategories: list[VisualElementSubcategory] = Field(
        ...,
        description="Specific types of visual elements identified (e.g., BAR_CHART, FLOWCHART).",
    )
    reasoning_complexity_score: int = Field(
        ...,
        ge=1,
        le=10,
        description="Complexity score from 1-10 indicating the depth of reasoning required.",
    )
    justification: str = Field(
        ...,
        description="Brief explanation of why this page is or isn't suitable for reasoning QA generation.",
    )
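

# The field descriptions above state two invariants (NONE is exclusive, and
# `contains_reasoning_content` must agree with its presence). Structured
# decoding does not guarantee them, so the sketch below shows one way to
# check them after generation. Illustrative only; not called by the pipeline.
def _is_consistent(pc: PageClassification) -> bool:
    """Return True if a classification obeys the NONE/reasoning invariants."""
    has_none = VisualElementCategory.NONE in pc.primary_categories
    if has_none and len(pc.primary_categories) > 1:
        return False  # NONE must be the ONLY category
    return pc.contains_reasoning_content != has_none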


# =============================================================================
# Prompt template
# =============================================================================

PROMPT_PAGE_CLASSIFICATION = """\
# ROLE AND OBJECTIVE
You are a document intelligence analyst specializing in visual reasoning assessment. Your task is to analyze document page images and determine their suitability for generating high-quality reasoning-based Question-Answer (QA) pairs.

# CLASSIFICATION TAXONOMY
Identify and classify ALL visual elements present in the image using these categories:

**QUANTITATIVE** - Data visualizations requiring numerical analysis
  • Bar charts, line graphs, scatter plots, pie charts, area graphs
  • Requires: trend analysis, value comparison, rate calculations

**LOGIC_DIAGRAMS** - Process and decision flows
  • Flowcharts, decision trees, process maps, algorithmic diagrams
  • Requires: conditional reasoning, path tracing, outcome prediction

**HIERARCHICAL** - Organizational and structural relationships
  • Organizational charts, mind maps, tree structures, taxonomies
  • Requires: understanding parent-child relationships, levels, dependencies

**SPATIAL_RELATIONAL** - Geographic and spatial layouts
  • Floor plans, blueprints, maps (choropleth, point, topographic)
  • Requires: distance estimation, position inference, spatial reasoning

**SCHEMATIC** - Technical diagrams with component relationships
  • Circuit diagrams, mechanical cross-sections, anatomical diagrams with labels
  • Requires: understanding connections, tracing signal/flow paths, component identification

**TABULAR** - Structured data in rows and columns
  • Tables with nested headers, merged cells, subtotals, calculated rows
  • Requires: cross-referencing values, performing calculations, identifying patterns

**INFOGRAPHIC** - Multi-modal composite narratives
  • Mixed visuals combining charts, text, icons, and data into cohesive stories
  • Requires: synthesizing information across multiple elements

**NONE** - Content without reasoning potential
  • Decorative images, simple photographs, plain text blocks, generic icons
  • Presentation slides with only text or bullet points (no visual elements)
  • No data relationships, calculations, or logical deductions possible

**Note on Presentation Content**: The format (e.g., presentation slide, document page) doesn't matter.
Classify based on the actual visual elements present:
  • Slide with bar chart → QUANTITATIVE
  • Slide with flowchart → LOGIC_DIAGRAMS
  • Slide with only text/bullets → NONE

# REASONING COMPLEXITY ASSESSMENT
Score pages 1-10 based on cognitive demand:

**High Complexity (8-10)**: Requires multi-step inference
  • Cross-referencing multiple data sources
  • Mathematical derivation (growth rates, percentages, trends)
  • Conditional logic chains (if-then-else reasoning)
  • Spatial or temporal reasoning across disconnected components

**Medium Complexity (4-7)**: Requires single-step analysis
  • Direct comparisons between values
  • Simple calculations from visible data
  • Following a single logical path
  • Identifying explicit patterns or relationships

**Low Complexity (1-3)**: Minimal reasoning
  • Direct lookup of visible information
  • Simple identification tasks
  • No relationships or calculations needed

# EVALUATION PROCESS
1. **Scan for visual elements**: Identify all charts, diagrams, tables, or structured content
2. **Classify elements**: Assign primary categories (up to 3, ordered by prominence)
3. **Identify subcategories**: Determine specific visual element types (e.g., BAR_CHART, FLOWCHART)
4. **Assess reasoning depth**: Determine if multi-step thinking is necessary
5. **Score complexity**: Rate 1-10 based on cognitive requirements
6. **Justify classification**: Explain why this page is or isn't suitable for reasoning QA

# DECISION CRITERIA

**CRITICAL RULES**:
1. If `primary_categories` contains NONE, it must be the ONLY category (do NOT mix NONE with other categories)
2. `contains_reasoning_content` must be **False** if NONE is present
3. If content has reasoning elements, do NOT include NONE at all
4. Ignore presentation format - classify by actual visual content (slide with chart = QUANTITATIVE, not PRESENTATION)

Mark `contains_reasoning_content: true` ONLY if:
  ✓ `primary_categories` does NOT contain NONE, AND
  ✓ At least one of these reasoning elements is present:
    - Quantitative comparisons possible (e.g., "Which region had highest growth?")
    - Logical paths to trace (e.g., "What happens if condition X fails?")
    - Mathematical derivations needed (e.g., "Calculate percentage change")
    - Spatial/temporal relationships to deduce
    - Complex table requiring cross-referencing

Mark `contains_reasoning_content: false` if:
  ✗ Only decorative or generic imagery → set primary_categories: ["NONE"]
  ✗ Plain text with no visual structure → set primary_categories: ["NONE"]
  ✗ Simple lists or single-column tables → set primary_categories: ["NONE"]
  ✗ Slides with only text/bullet points (no charts/diagrams) → set primary_categories: ["NONE"]
  ✗ No data relationships to explore → set primary_categories: ["NONE"]

**Classification Logic**:
- Either the page has reasoning content (assign specific categories like QUANTITATIVE, TABULAR, etc.)
- OR it doesn't (assign only ["NONE"])
- NEVER mix NONE with other categories

# SUBCATEGORIES
For each visual element found, identify the specific subcategory (e.g., BAR_CHART, FLOWCHART, FLOOR_PLAN).
Include the most prominent subcategories in the `subcategories` list, ordered by importance.\
"""


# =============================================================================
# Pipeline configuration
# =============================================================================


def build_config(
    seed_path: str = "seed.parquet",
    model_alias: str = "qwen-vl",
    model_id: str = DEFAULT_VLM_MODEL,
) -> dd.DataDesignerConfigBuilder:
    model_configs = [
        dd.ModelConfig(
            alias=model_alias,
            model=model_id,
            provider=VLLM_PROVIDER_NAME,
            inference_parameters=dd.ChatCompletionInferenceParams(
                timeout=1200,
                max_tokens=100000,
                max_parallel_requests=32,
            ),
        ),
    ]

    config_builder = dd.DataDesignerConfigBuilder(model_configs=model_configs)

    config_builder.with_seed_dataset(
        dd.LocalFileSeedSource(path=seed_path),
        sampling_strategy=dd.SamplingStrategy.ORDERED,
    )

    config_builder.add_column(
        dd.LLMStructuredColumnConfig(
            name="page_classification",
            model_alias=model_alias,
            prompt=PROMPT_PAGE_CLASSIFICATION,
            output_format=PageClassification,
            multi_modal_context=[
                dd.ImageContext(
                    # Expects a single-element JSON array from the per-page seed.
                    column_name="png_images_base64",
                    data_type=dd.ModalityDataType.BASE64,
                    image_format=dd.ImageFormat.PNG,
                ),
            ],
        )
    )

    return config_builder


def create_dataset(
    config_builder: dd.DataDesignerConfigBuilder,
    num_records: int,
    vllm_endpoint: str,
    artifact_path: Path | str | None = None,
) -> DatasetCreationResults:
    model_providers = [
        dd.ModelProvider(
            name=VLLM_PROVIDER_NAME,
            endpoint=vllm_endpoint,
        ),
    ]
    data_designer = DataDesigner(
        artifact_path=artifact_path,
        model_providers=model_providers,
    )
    data_designer.set_run_config(dd.RunConfig(progress_bar=True, disable_early_shutdown=True))
    results = data_designer.create(config_builder, num_records=num_records, dataset_name="page_classification")
    return results
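

# Post-processing sketch (not invoked by this recipe): a typical next step is
# to keep only pages flagged as containing reasoning content. How the
# structured column is serialized in the output parquet is an assumption;
# adjust the parsing to match your data-designer version.
def _filter_reasoning_pages(dataset_path: Path | str):
    """Sketch: load the generated dataset and keep reasoning-suitable pages.

    Assumes pandas is installed and that `page_classification` is stored
    either as dicts or as JSON strings.
    """
    import json

    import pandas as pd  # assumed available; not a declared dependency

    df = pd.read_parquet(dataset_path)
    parsed = df["page_classification"].apply(
        lambda v: v if isinstance(v, dict) else json.loads(v)
    )
    return df[parsed.apply(lambda c: c["contains_reasoning_content"])]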


if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument(
        "--vllm-endpoint",
        type=str,
        required=True,
        help="Base URL of the vLLM server hosting the VLM (e.g. http://localhost:8000/v1)",
    )
    parser.add_argument("--seed-path", type=str, required=True, help="Path to the seed parquet file")
    parser.add_argument("--model-alias", type=str, default="qwen-vl")
    parser.add_argument("--model-id", type=str, default=DEFAULT_VLM_MODEL)
    parser.add_argument("--num-records", type=int, default=5)
    parser.add_argument("--artifact-path", type=str, default=None)
    args = parser.parse_args()

    config_builder = build_config(
        seed_path=args.seed_path,
        model_alias=args.model_alias,
        model_id=args.model_id,
    )
    results = create_dataset(
        config_builder,
        num_records=args.num_records,
        vllm_endpoint=args.vllm_endpoint,
        artifact_path=args.artifact_path,
    )

    print(f"Dataset saved to: {results.artifact_storage.final_dataset_path}")

    results.load_analysis().to_report()