Serialization with Marshmallow

Advanced Data Serialization and Validation for Flask APIs

Introduction to Serialization

When building APIs, serialization is the process of converting complex data structures (like database objects) into formats that can be easily transmitted over a network, typically JSON. The reverse process, deserialization, converts received data back into application objects. This two-way conversion is essential for APIs that need to both send and receive data.

While Flask-RESTful provides basic serialization through its marshaling system, Marshmallow offers a more powerful, flexible, and feature-rich solution for handling complex serialization and validation needs.

Think of serialization like language translation. Marshmallow acts as a skilled translator, converting between the "language" your application speaks internally (Python objects) and the "language" of APIs (JSON), ensuring nothing gets lost in translation and all communication follows the correct grammar rules.

graph LR
    A[Python Objects] -->|Serialization| B[JSON/API Data]
    B -->|Deserialization| A
    C[Marshmallow] --- A
    C --- B
    style C fill:#f9f,stroke:#333,stroke-width:2px

Why Use Marshmallow?

While Flask-RESTful's marshaling system works well for basic APIs, Marshmallow offers several advantages for more complex applications, summarized in this comparison:

| Feature | Flask-RESTful Marshaling | Marshmallow |
|---|---|---|
| Serialization | Yes | Yes |
| Deserialization | No (separate RequestParser) | Yes (unified) |
| Validation | Basic | Advanced |
| Nested Structures | Limited | Excellent |
| ORM Integration | Limited | Extensive |
| Error Messages | Basic | Detailed, customizable |
| Partial Updates | Difficult | Built-in support |

Getting Started with Marshmallow

Let's set up Marshmallow and create a basic schema:

# Install Marshmallow
pip install marshmallow

# For Flask integration
pip install flask-marshmallow

# For SQLAlchemy integration
pip install marshmallow-sqlalchemy

Now, let's create a simple schema:

from datetime import datetime
from marshmallow import Schema, ValidationError, fields

# Define a schema
class UserSchema(Schema):
    id = fields.Integer(dump_only=True)  # Read-only field
    username = fields.String(required=True)
    email = fields.Email(required=True)
    created_at = fields.DateTime(dump_only=True)
    bio = fields.String()

# Create an instance of the schema
user_schema = UserSchema()
users_schema = UserSchema(many=True)  # For serializing collections

# Sample data (DateTime fields serialize datetime objects, not strings)
user_data = {
    'id': 1,
    'username': 'johndoe',
    'email': 'john@example.com',
    'created_at': datetime(2023, 1, 15, 12, 30, 45),
    'bio': 'Software developer and tech enthusiast.'
}

# Serialize the data
result = user_schema.dump(user_data)
print(result)  # JSON-compatible dict; created_at becomes '2023-01-15T12:30:45'

# Deserialize and validate data
try:
    user = user_schema.load({'username': 'janedoe', 'email': 'jane@example.com'})
    print(user)  # Valid data
except ValidationError as err:
    print(err.messages)  # Validation errors

This example demonstrates the basic usage of Marshmallow:

  1. Define a schema class that inherits from Schema
  2. Define fields with types and validation rules
  3. Create instances of the schema for single objects and collections
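  4. Use dump() to serialize data (Python objects to JSON-compatible dicts; dumps() returns a JSON string)
  5. Use load() to deserialize and validate data (parsed JSON input to Python data; loads() accepts a JSON string)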

Field Types and Validation

Marshmallow provides a rich set of field types and validation options:

Common Field Types

| Field Type | Python Type | Description |
|---|---|---|
| fields.String | str | String field |
| fields.Integer | int | Integer field |
| fields.Float | float | Floating point field |
| fields.Boolean | bool | Boolean field |
| fields.DateTime | datetime.datetime | Date and time field |
| fields.Date | datetime.date | Date field |
| fields.Time | datetime.time | Time field |
| fields.Decimal | decimal.Decimal | Precise decimal field |
| fields.Email | str | Email field with validation |
| fields.URL | str | URL field with validation |
| fields.Dict | dict | Dictionary field |
| fields.List | list | List field |
| fields.Nested | dict or object | Nested object field |

Field Options

Fields can be configured with various options:

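from marshmallow import fields
from marshmallow.validate import Length

# A required field with several options
username = fields.String(
    required=True,                 # Field is required during deserialization
    validate=Length(               # Apply validation constraints; the validator
        min=3, max=50,             # owns its error message ({min}/{max} are filled in)
        error='Username must be between {min} and {max} characters'
    ),
    error_messages={               # Custom messages for the field's built-in errors
        'required': 'Username is required'
    },
    dump_only=False,               # True would make the field output-only
    load_only=False,               # True would make the field input-only
    data_key='user_name',          # Different key name in the data
    attribute='name',              # Different attribute name in the object
)

# Defaults belong on optional fields (marshmallow rejects load_default with required=True).
# In marshmallow 3.13+, dump_default/load_default replace the older default/missing
nickname = fields.String(
    dump_default='guest',          # Default value during serialization
    load_default=None              # Default value during deserialization
)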

Common Validators

Marshmallow includes several built-in validators:

from marshmallow import Schema, fields
from marshmallow.validate import Length, Range, OneOf, Email, URL, Regexp

class UserSchema(Schema):
    username = fields.String(validate=Length(min=3, max=50))
    age = fields.Integer(validate=Range(min=18, max=120))
    role = fields.String(validate=OneOf(['user', 'admin', 'editor']))
    email = fields.String(validate=Email())
    website = fields.String(validate=URL())
    phone = fields.String(validate=Regexp(r'^\d{3}-\d{3}-\d{4}$'))

You can also create custom validators:

from marshmallow import Schema, ValidationError, fields

def validate_not_in_blacklist(value):
    blacklist = ['admin', 'root', 'superuser']
    if value.lower() in blacklist:
        raise ValidationError('This username is reserved.')

class UserSchema(Schema):
    username = fields.String(validate=validate_not_in_blacklist)

Nested Schemas and Relationships

Marshmallow excels at handling complex, nested data structures, making it ideal for representing relationships between resources:

Nested Objects

class AddressSchema(Schema):
    street = fields.String(required=True)
    city = fields.String(required=True)
    state = fields.String(required=True)
    zip_code = fields.String(required=True)
    country = fields.String(required=True)

class UserSchema(Schema):
    id = fields.Integer(dump_only=True)
    username = fields.String(required=True)
    email = fields.Email(required=True)
    # Nested schema for the user's address
    address = fields.Nested(AddressSchema)

When serializing, this will produce nested JSON:

{
  "id": 1,
  "username": "johndoe",
  "email": "john@example.com",
  "address": {
    "street": "123 Main St",
    "city": "Anytown",
    "state": "CA",
    "zip_code": "12345",
    "country": "USA"
  }
}

Nested Collections

class CommentSchema(Schema):
    id = fields.Integer(dump_only=True)
    content = fields.String(required=True)
    created_at = fields.DateTime(dump_only=True)
    
class PostSchema(Schema):
    id = fields.Integer(dump_only=True)
    title = fields.String(required=True)
    content = fields.String(required=True)
    created_at = fields.DateTime(dump_only=True)
    # Nested collection of comments
    comments = fields.Nested(CommentSchema, many=True)

This will serialize a post with its comments:

{
  "id": 123,
  "title": "My First Post",
  "content": "Hello, world!",
  "created_at": "2023-01-15T12:30:45",
  "comments": [
    {
      "id": 1,
      "content": "Great post!",
      "created_at": "2023-01-15T14:22:10"
    },
    {
      "id": 2,
      "content": "Thanks for sharing.",
      "created_at": "2023-01-16T09:45:30"
    }
  ]
}

Self-Referential Nested Schemas

For hierarchical data like categories or comments with replies, you can use self-referential schemas:

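from marshmallow import Schema, fields

class CategorySchema(Schema):
    id = fields.Integer(dump_only=True)
    name = fields.String(required=True)
    description = fields.String()
    
    # Self-reference for parent-child relationships: passing a callable defers
    # creating the nested schema until it is needed (instead of the older 'self' string)
    parent_id = fields.Integer(allow_none=True)
    children = fields.List(
        fields.Nested(lambda: CategorySchema(exclude=('parent_id',)))
    )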

Customizing Nested Fields

You can customize which fields are included in nested objects:

# Include only specific fields
author = fields.Nested(UserSchema, only=('id', 'username'))

# Exclude specific fields
author = fields.Nested(UserSchema, exclude=('email', 'bio'))

# Conditional nested content: Nested has no "include only if" option,
# so filter the related objects in a Method field instead
class PostSchema(Schema):
    comments = fields.Method('get_approved_comments')

    def get_approved_comments(self, obj):
        # Only include approved comments (an empty list if none qualify)
        approved = [c for c in obj.comments if c.approved]
        return CommentSchema(many=True).dump(approved)

Customizing Serialization and Deserialization

Marshmallow provides hooks for customizing how data is processed during serialization and deserialization:

Method Fields

Method fields let you compute values during serialization:

class UserSchema(Schema):
    id = fields.Integer(dump_only=True)
    username = fields.String(required=True)
    first_name = fields.String(required=True)
    last_name = fields.String(required=True)
    
    # Method field for computed value
    full_name = fields.Method('get_full_name')
    
    def get_full_name(self, obj):
        return f"{obj.first_name} {obj.last_name}"

Function Fields

Similar to method fields, but using a function:

def get_likes_count(obj):
    return len(obj.likes)

class PostSchema(Schema):
    id = fields.Integer(dump_only=True)
    title = fields.String(required=True)
    content = fields.String(required=True)
    
    # Function field
    likes_count = fields.Function(get_likes_count)

Pre-processing and Post-processing Hooks

Hooks allow you to modify data before or after processing:

from datetime import datetime
from marshmallow import Schema, fields, pre_load, post_load, pre_dump, post_dump

class UserSchema(Schema):
    id = fields.Integer(dump_only=True)
    username = fields.String(required=True)
    email = fields.Email(required=True)
    password = fields.String(load_only=True, required=True)
    created_at = fields.DateTime(dump_only=True)
    
    # Pre-load hook: process data before validation
    @pre_load
    def lowercase_email(self, data, **kwargs):
        # Convert email to lowercase, returning a new dict so the caller's input is not mutated
        if isinstance(data, dict) and data.get('email'):
            data = {**data, 'email': data['email'].lower()}
        return data
    
    # Post-load hook: process data after validation
    @post_load
    def make_user(self, data, **kwargs):
        # Convert dict to a User object (User is your application's model class)
        return User(**data)
    
    # Pre-dump hook: process object before serialization
    @pre_dump
    def prepare_data(self, obj, **kwargs):
        # If obj is a dict with a string timestamp, convert it to datetime
        if isinstance(obj, dict) and isinstance(obj.get('created_at'), str):
            obj = {**obj, 'created_at': datetime.fromisoformat(obj['created_at'])}
        return obj
    
    # Post-dump hook: process output after serialization
    @post_dump
    def remove_none_values(self, data, **kwargs):
        # Remove fields with None values
        return {key: value for key, value in data.items() if value is not None}

Data Transformation Example

A common use case is transforming data during serialization or deserialization:

import json
from marshmallow import Schema, fields, pre_load, post_dump

class UserProfileSchema(Schema):
    id = fields.Integer(dump_only=True)
    username = fields.String(required=True)
    
    # Store preferences as JSON in DB, but expose as dict in API
    preferences = fields.Dict()
    
    @pre_load
    def parse_preferences(self, data, **kwargs):
        # If preferences is a JSON string, parse it
        if 'preferences' in data and isinstance(data['preferences'], str):
            try:
                data['preferences'] = json.loads(data['preferences'])
            except json.JSONDecodeError:
                data['preferences'] = {}
        return data
    
    @post_dump
    def format_data(self, data, **kwargs):
        # Add metadata
        data['_type'] = 'user_profile'
        data['_links'] = {
            'self': f"/api/users/{data['id']}",
            'settings': f"/api/users/{data['id']}/settings"
        }
        return data

Integrating with Flask-SQLAlchemy

The marshmallow-sqlalchemy extension simplifies working with SQLAlchemy models:

from datetime import datetime
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from flask_marshmallow import Marshmallow

app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///app.db'
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False

# Initialize extensions
db = SQLAlchemy(app)
ma = Marshmallow(app)

# Define models
class User(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.String(80), unique=True, nullable=False)
    email = db.Column(db.String(120), unique=True, nullable=False)
    bio = db.Column(db.Text)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    
    posts = db.relationship('Post', backref='author', lazy=True)
    
class Post(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    title = db.Column(db.String(100), nullable=False)
    content = db.Column(db.Text, nullable=False)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False)
    
# Define schemas
class UserSchema(ma.SQLAlchemySchema):
    class Meta:
        model = User
        # Fields to include
        fields = ('id', 'username', 'email', 'bio', 'created_at', 'posts')
        
    # Auto-generated fields from the model
    id = ma.auto_field()
    username = ma.auto_field()
    email = ma.auto_field()
    bio = ma.auto_field()
    created_at = ma.auto_field()
    
    # Relationship field
    posts = ma.Nested('PostSchema', many=True, exclude=('author',))
    
class PostSchema(ma.SQLAlchemySchema):
    class Meta:
        model = Post
        # Include all fields
        fields = ('id', 'title', 'content', 'created_at', 'user_id', 'author')
        
    id = ma.auto_field()
    title = ma.auto_field()
    content = ma.auto_field()
    created_at = ma.auto_field()
    user_id = ma.auto_field()
    
    # Nested author field
    author = ma.Nested(UserSchema, only=('id', 'username'))
    
# Create schema instances
user_schema = UserSchema()
users_schema = UserSchema(many=True)
post_schema = PostSchema()
posts_schema = PostSchema(many=True)

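With this approach, auto_field() derives each field's type from the corresponding model column, while relationships such as posts and author are declared explicitly with Nested so you control how much of the related object is exposed.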

For even less boilerplate, SQLAlchemyAutoSchema generates a field for every model column automatically:

class UserSchema(ma.SQLAlchemyAutoSchema):
    class Meta:
        model = User
        include_fk = True  # Include foreign keys
        include_relationships = True  # Include relationships
        load_instance = True  # Deserialize to model instances
        
class PostSchema(ma.SQLAlchemyAutoSchema):
    class Meta:
        model = Post
        include_fk = True
        include_relationships = True
        load_instance = True
        
    # Override the author field for custom behavior
    author = ma.Nested(UserSchema, only=('id', 'username'))

With load_instance=True, the schema's load() method will create model instances directly, making it easy to deserialize API data to database objects.

Using Marshmallow with Flask-RESTful

Marshmallow can be used alongside Flask-RESTful to provide enhanced serialization and validation:

from flask import Flask, request
from flask_restful import Api, Resource
from flask_sqlalchemy import SQLAlchemy
from flask_marshmallow import Marshmallow
from marshmallow import ValidationError

app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///app.db'
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False

# Initialize extensions
db = SQLAlchemy(app)
ma = Marshmallow(app)
api = Api(app)

# Define models
class User(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.String(80), unique=True, nullable=False)
    email = db.Column(db.String(120), unique=True, nullable=False)

# Define schemas
class UserSchema(ma.SQLAlchemyAutoSchema):
    class Meta:
        model = User
        load_instance = True
        dump_only = ('id',)  # Clients cannot set the primary key

# Create schema instances
user_schema = UserSchema()
users_schema = UserSchema(many=True)

# Define resources
class UserResource(Resource):
    def get(self, user_id):
        user = User.query.get_or_404(user_id)
        return user_schema.dump(user)
    
    def put(self, user_id):
        user = User.query.get_or_404(user_id)
        
        try:
            # Deserialize and validate input data
            updated_user = user_schema.load(request.json, instance=user, partial=True)
            db.session.commit()
            return user_schema.dump(updated_user)
        except ValidationError as err:
            return {'message': 'Validation error', 'errors': err.messages}, 400
    
    def delete(self, user_id):
        user = User.query.get_or_404(user_id)
        db.session.delete(user)
        db.session.commit()
        return '', 204

class UserListResource(Resource):
    def get(self):
        users = User.query.all()
        return users_schema.dump(users)
    
    def post(self):
        try:
            # Create a new user instance from the data
            user = user_schema.load(request.json)
            db.session.add(user)
            db.session.commit()
            return user_schema.dump(user), 201
        except ValidationError as err:
            return {'message': 'Validation error', 'errors': err.messages}, 400

# Register resources
api.add_resource(UserResource, '/users/<int:user_id>')
api.add_resource(UserListResource, '/users')

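The key difference from Flask-RESTful's marshaling is that one schema covers both directions: dump() replaces @marshal_with for responses, and load() replaces RequestParser for validating request bodies, including partial updates via partial=True.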

Advanced Features

Field Masking

Allow clients to request only the fields they need:

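from flask import request

@app.route('/users/<int:user_id>')
def get_user(user_id):
    user = User.query.get_or_404(user_id)
    
    # Get requested fields from the query string, e.g. ?fields=id,username
    requested = request.args.get('fields')
    if requested:
        # Create a schema with only the requested fields; marshmallow raises
        # ValueError for names that are not declared on the schema
        only = [name.strip() for name in requested.split(',') if name.strip()]
        try:
            schema = UserSchema(only=only)
        except ValueError as err:
            return {'message': str(err)}, 400
    else:
        schema = user_schema
    
    return schema.dump(user)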

Schema Inheritance

Create hierarchies of schemas to reuse field definitions:

from marshmallow import Schema, fields
from marshmallow.validate import OneOf

class BasePersonSchema(Schema):
    id = fields.Integer(dump_only=True)
    first_name = fields.String(required=True)
    last_name = fields.String(required=True)
    email = fields.Email(required=True)

class UserSchema(BasePersonSchema):
    username = fields.String(required=True)
    password = fields.String(load_only=True, required=True)
    role = fields.String(validate=OneOf(['user', 'admin']))

class EmployeeSchema(BasePersonSchema):
    department = fields.String(required=True)
    hire_date = fields.Date(required=True)
    salary = fields.Decimal(places=2, load_only=True)

Context-dependent Serialization

Pass context to schemas to customize behavior:

from marshmallow import fields

class PostSchema(ma.SQLAlchemyAutoSchema):
    class Meta:
        model = Post
    
    # Field that changes based on context
    content = fields.Method('get_content')
    
    def get_content(self, obj):
        # Get user role from context
        user_role = self.context.get('user_role', 'guest')
        
        # Show full content to authorized users and to the post's own author
        if user_role in ['admin', 'editor'] or obj.user_id == self.context.get('user_id'):
            return obj.content
        
        # Show preview to others
        preview_length = self.context.get('preview_length', 100)
        if len(obj.content) > preview_length:
            return obj.content[:preview_length] + '...'
        return obj.content

# Using context (marshmallow 3.x): context is passed to the schema constructor, not to dump().
# current_user assumes an authentication extension such as Flask-Login
admin_post = PostSchema(context={'user_role': 'admin'}).dump(post)
user_post = PostSchema(context={'user_role': 'user', 'user_id': current_user.id}).dump(post)
guest_post = PostSchema(context={'user_role': 'guest', 'preview_length': 50}).dump(post)

Schema Registries

For applications with many schemas, a registry can help manage them:

# Create a schema registry: store classes and constructor options, and build
# a fresh instance per call so per-request context never leaks between calls
schemas = {
    'user': (UserSchema, {}),
    'users': (UserSchema, {'many': True}),
    'post': (PostSchema, {}),
    'posts': (PostSchema, {'many': True}),
    'comment': (CommentSchema, {}),
    'comments': (CommentSchema, {'many': True}),
    'user_profile': (UserProfileSchema, {}),
    'category': (CategorySchema, {}),
    'categories': (CategorySchema, {'many': True})
}

def get_schema(name, **context):
    """Get a new schema instance by name with optional context (marshmallow 3.x)."""
    try:
        schema_cls, options = schemas[name]
    except KeyError:
        raise ValueError(f"Unknown schema: {name}") from None
    
    return schema_cls(context=context, **options)

# Using the registry
user_data = get_schema('user', user_role='admin').dump(user)
posts_data = get_schema('posts', preview_length=200).dump(posts)

Error Handling and Validation

Proper error handling is crucial for API usability. Marshmallow provides detailed validation errors:

from marshmallow import ValidationError
from flask import jsonify, request

# Global error handler for validation errors
@app.errorhandler(ValidationError)
def handle_validation_error(err):
    return jsonify({
        'error': 'Validation error',
        'messages': err.messages
    }), 400

@app.route('/users', methods=['POST'])
def create_user():
    # Deserialize and validate the request data. A ValidationError raised here
    # propagates to the global error handler above, which returns the 400 response
    user = user_schema.load(request.json)
    db.session.add(user)
    db.session.commit()
    return user_schema.dump(user), 201

A validation error response might look like:

{
  "error": "Validation error",
  "messages": {
    "email": ["Not a valid email address."],
    "username": ["Missing data for required field."],
    "age": ["Must be greater than or equal to 18."]
  }
}

Custom Error Messages

You can customize error messages for fields and validators:

from marshmallow import Schema, fields
from marshmallow.validate import Length

class UserSchema(Schema):
    username = fields.String(
        required=True,
        # Validator messages are set on the validator itself; {min} and {max} are filled in
        validate=Length(min=3, max=50, error='Username must be between {min} and {max} characters'),
        error_messages={
            'required': 'Please provide a username',
            'invalid': 'Not a valid string'
        }
    )
    email = fields.Email(
        required=True,
        error_messages={
            'required': 'Email address is required',
            'invalid': 'Please provide a valid email address'
        }
    )

Field-level Validation

Add validation methods for specific fields:

from marshmallow import Schema, fields, validates, validates_schema, ValidationError

class UserSchema(Schema):
    username = fields.String(required=True)
    email = fields.Email(required=True)
    password = fields.String(required=True, load_only=True)
    password_confirm = fields.String(required=True, load_only=True)
    
    # Field-level validator: @validates registers the method for one field
    @validates('username')
    def validate_username(self, value, **kwargs):
        # Check if username is already taken
        existing = User.query.filter_by(username=value).first()
        if existing:
            raise ValidationError('This username is already taken')
    
    # Validate password strength
    @validates('password')
    def validate_password(self, value, **kwargs):
        if len(value) < 8:
            raise ValidationError('Password must be at least 8 characters')
        if not any(c.isupper() for c in value):
            raise ValidationError('Password must contain at least one uppercase letter')
        if not any(c.isdigit() for c in value):
            raise ValidationError('Password must contain at least one number')
    
    # Schema-level validation
    @validates_schema
    def validate_passwords_match(self, data, **kwargs):
        if data.get('password') != data.get('password_confirm'):
            raise ValidationError('Passwords must match', 'password_confirm')

Schema-level Validation

Validate multiple fields together:

from marshmallow import Schema, fields, validates_schema, ValidationError

class EventSchema(Schema):
    title = fields.String(required=True)
    start_date = fields.Date(required=True)
    end_date = fields.Date(required=True)
    
    @validates_schema
    def validate_dates(self, data, **kwargs):
        if 'start_date' in data and 'end_date' in data:
            if data['end_date'] < data['start_date']:
                raise ValidationError('End date must not be before start date', 'end_date')

Complete Example: Blog API with Marshmallow

Let's put it all together with a complete example of a blog API using Flask, SQLAlchemy, and Marshmallow:

from flask import Flask, request, jsonify
from flask_sqlalchemy import SQLAlchemy
from flask_marshmallow import Marshmallow
from marshmallow import ValidationError, validates, validates_schema
from datetime import datetime

# Initialize app
app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///blog.db'
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False

# Initialize extensions
db = SQLAlchemy(app)
ma = Marshmallow(app)

# Define models
class User(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.String(50), unique=True, nullable=False)
    email = db.Column(db.String(100), unique=True, nullable=False)
    password = db.Column(db.String(100), nullable=False)  # Would be hashed in real app
    bio = db.Column(db.Text)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    
    posts = db.relationship('Post', backref='author', lazy=True, cascade='all, delete-orphan')
    comments = db.relationship('Comment', backref='author', lazy=True, cascade='all, delete-orphan')

class Category(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(50), unique=True, nullable=False)
    description = db.Column(db.Text)
    
    posts = db.relationship('Post', backref='category', lazy=True)

class Post(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    title = db.Column(db.String(100), nullable=False)
    content = db.Column(db.Text, nullable=False)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
    published = db.Column(db.Boolean, default=False)
    
    user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False)
    category_id = db.Column(db.Integer, db.ForeignKey('category.id'), nullable=True)
    
    comments = db.relationship('Comment', backref='post', lazy=True, cascade='all, delete-orphan')

class Comment(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    content = db.Column(db.Text, nullable=False)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    
    user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False)
    post_id = db.Column(db.Integer, db.ForeignKey('post.id'), nullable=False)

# Create the database tables
with app.app_context():
    db.create_all()

# Define schemas
class UserSchema(ma.SQLAlchemyAutoSchema):
    class Meta:
        model = User
        load_only = ('password',)  # Accept password on input, never include it in output
        dump_only = ('id', 'created_at')  # Server-generated fields
        load_instance = True  # Deserialize to model instances
    
    # Add validation
    @validates('username')
    def validate_username(self, value, **kwargs):
        if len(value) < 3:
            raise ValidationError('Username must be at least 3 characters')

class CategorySchema(ma.SQLAlchemyAutoSchema):
    class Meta:
        model = Category
        dump_only = ('id',)
        load_instance = True

class CommentSchema(ma.SQLAlchemyAutoSchema):
    class Meta:
        model = Comment
        include_fk = True  # Include user_id and post_id
        dump_only = ('id', 'created_at')
        load_instance = True
    
    # Add author information (read-only)
    author = ma.Nested(UserSchema, only=('id', 'username'), dump_only=True)

class PostSchema(ma.SQLAlchemyAutoSchema):
    class Meta:
        model = Post
        include_fk = True
        dump_only = ('id', 'created_at', 'updated_at')
        load_instance = True
    
    # Add nested fields (read-only)
    author = ma.Nested(UserSchema, only=('id', 'username'), dump_only=True)
    category = ma.Nested(CategorySchema, only=('id', 'name'), dump_only=True)
    comments = ma.Nested(CommentSchema, many=True, exclude=('post_id',), dump_only=True)
    
    # Add computed fields
    comment_count = ma.Function(lambda obj: len(obj.comments))
    
    # Validation
    @validates('title')
    def validate_title(self, value, **kwargs):
        if len(value) < 5:
            raise ValidationError('Title must be at least 5 characters')

# Schema instances
user_schema = UserSchema()
users_schema = UserSchema(many=True)
category_schema = CategorySchema()
categories_schema = CategorySchema(many=True)
post_schema = PostSchema()
posts_schema = PostSchema(many=True)
comment_schema = CommentSchema()
comments_schema = CommentSchema(many=True)

# Error handler
@app.errorhandler(ValidationError)
def handle_validation_error(err):
    return jsonify({
        'error': 'Validation error',
        'messages': err.messages
    }), 400

# User endpoints
@app.route('/users', methods=['GET'])
def get_users():
    users = User.query.all()
    return jsonify(users_schema.dump(users))

@app.route('/users/<int:user_id>', methods=['GET'])
def get_user(user_id):
    user = User.query.get_or_404(user_id)
    return jsonify(user_schema.dump(user))

@app.route('/users', methods=['POST'])
def create_user():
    try:
        user = user_schema.load(request.json)
        db.session.add(user)
        db.session.commit()
        return jsonify(user_schema.dump(user)), 201
    except ValidationError as err:
        return jsonify({
            'error': 'Validation error',
            'messages': err.messages
        }), 400

# Category endpoints
@app.route('/categories', methods=['GET'])
def get_categories():
    categories = Category.query.all()
    return jsonify(categories_schema.dump(categories))

@app.route('/categories/<int:category_id>', methods=['GET'])
def get_category(category_id):
    category = Category.query.get_or_404(category_id)
    return jsonify(category_schema.dump(category))

@app.route('/categories', methods=['POST'])
def create_category():
    try:
        category = category_schema.load(request.json)
        db.session.add(category)
        db.session.commit()
        return jsonify(category_schema.dump(category)), 201
    except ValidationError as err:
        raise err

# Post endpoints
@app.route('/posts', methods=['GET'])
def get_posts():
    # Add filtering
    category_id = request.args.get('category_id', type=int)
    user_id = request.args.get('user_id', type=int)
    published = request.args.get('published')
    
    # Start with base query
    query = Post.query
    
    # Apply filters
    if category_id:
        query = query.filter_by(category_id=category_id)
    if user_id:
        query = query.filter_by(user_id=user_id)
    if published is not None:
        is_published = published.lower() == 'true'
        query = query.filter_by(published=is_published)
    
    # Execute query and serialize
    posts = query.order_by(Post.created_at.desc()).all()
    result = posts_schema.dump(posts)
    
    return jsonify(result)

@app.route('/posts/<int:post_id>', methods=['GET'])
def get_post(post_id):
    post = Post.query.get_or_404(post_id)
    return jsonify(post_schema.dump(post))

@app.route('/posts', methods=['POST'])
def create_post():
    try:
        post = post_schema.load(request.json)
        db.session.add(post)
        db.session.commit()
        return jsonify(post_schema.dump(post)), 201
    except ValidationError as err:
        raise err

@app.route('/posts/<int:post_id>', methods=['PUT'])
def update_post(post_id):
    post = Post.query.get_or_404(post_id)
    
    try:
        # Update existing instance with partial data
        post = post_schema.load(request.json, instance=post, partial=True)
        db.session.commit()
        return jsonify(post_schema.dump(post))
    except ValidationError as err:
        raise err

@app.route('/posts/<int:post_id>', methods=['DELETE'])
def delete_post(post_id):
    post = Post.query.get_or_404(post_id)
    db.session.delete(post)
    db.session.commit()
    return '', 204

# Comment endpoints
@app.route('/posts/<int:post_id>/comments', methods=['GET'])
def get_post_comments(post_id):
    post = Post.query.get_or_404(post_id)
    return jsonify(comments_schema.dump(post.comments))

@app.route('/posts/<int:post_id>/comments', methods=['POST'])
def add_comment(post_id):
    post = Post.query.get_or_404(post_id)
    
    try:
        # Create comment linked to the post
        data = request.json
        data['post_id'] = post_id
        
        comment = comment_schema.load(data)
        db.session.add(comment)
        db.session.commit()
        return jsonify(comment_schema.dump(comment)), 201
    except ValidationError as err:
        raise err

if __name__ == '__main__':
    app.run(debug=True)

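This example demonstrates a complete blog API with related SQLAlchemy models, auto-generated schemas with nested and computed fields, field-level validation, a global ValidationError handler, and endpoints for users, categories, posts, and comments, including query-string filtering and partial updates for posts.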

Practical Activity: Task Manager API with Marshmallow

Let's apply what we've learned by enhancing the task manager API from the previous lecture with Marshmallow:

Here's the outline:

  1. Set up Flask, SQLAlchemy, and Marshmallow
  2. Create Task, Category, and User models
  3. Define schemas for each model with appropriate validation
  4. Implement API endpoints with CRUD operations
  5. Add filtering, sorting, and pagination
  6. Implement proper error handling

Start by creating the following files:

task_manager/
├── app.py            # Main application file
├── models.py         # Database models
├── schemas.py        # Marshmallow schemas
├── api.py            # API routes
└── requirements.txt  # Dependencies

Implementation steps:

  1. Define Task, Category, and User models with appropriate relationships
  2. Create schemas for each model with validation rules
  3. Implement API endpoints for all models
  4. Add validation and error handling
  5. Implement filtering, sorting, and pagination
  6. Test the API endpoints

This activity will help you practice using Marshmallow for serialization, deserialization, and validation in a real-world API.
