Introduction to Serialization
When building APIs, serialization is the process of converting complex data structures (like database objects) into formats that can be easily transmitted over a network, typically JSON. The reverse process, deserialization, converts received data back into application objects. This two-way conversion is essential for APIs that need to both send and receive data.
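To make the two directions concrete, here is a minimal sketch of a manual round trip using only the standard library. The `User` dataclass and the two helper functions are hypothetical names chosen for illustration; they show the kind of hand-written conversion code that a serialization library is meant to replace.

```python
import json
from dataclasses import dataclass, asdict
from datetime import datetime


@dataclass
class User:
    username: str
    email: str
    created_at: datetime


def serialize_user(user: User) -> str:
    """Application object -> JSON text."""
    payload = asdict(user)
    # datetime is not a JSON type, so it is converted to an ISO 8601 string
    payload["created_at"] = user.created_at.isoformat()
    return json.dumps(payload)


def deserialize_user(text: str) -> User:
    """JSON text -> application object."""
    raw = json.loads(text)
    return User(
        username=raw["username"],
        email=raw["email"],
        created_at=datetime.fromisoformat(raw["created_at"]),
    )


original = User("johndoe", "john@example.com", datetime(2023, 1, 15, 12, 30, 45))
wire_format = serialize_user(original)
restored = deserialize_user(wire_format)
```

Every new field needs a change in both helpers, and nothing here validates the input, which is the gap the rest of this section addresses.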
While Flask-RESTful provides basic serialization through its marshaling system, Marshmallow offers a more powerful, flexible, and feature-rich solution for handling complex serialization and validation needs.
Think of serialization like language translation. Marshmallow acts as a skilled translator, converting between the "language" your application speaks internally (Python objects) and the "language" of APIs (JSON), ensuring nothing gets lost in translation and all communication follows the correct grammar rules.
Why Use Marshmallow?
While Flask-RESTful's marshaling system works well for basic APIs, Marshmallow offers several advantages for more complex applications:
- Two-way conversion: Both serialization (dump) and deserialization (load) in one library
- Complex validation: More sophisticated validation rules and error handling
- Schema reuse: Reuse schemas across different parts of your application
- Nested schemas: Better handling of complex, nested data structures
- Data transformation: Pre-processing and post-processing hooks for data
- Partial updates: Better support for partial data updates
- Integration with ORMs: Direct integration with SQLAlchemy and other ORMs
- Field-level control: More granular control over individual fields
Here's a quick comparison of Flask-RESTful's marshaling and Marshmallow:
| Feature | Flask-RESTful Marshaling | Marshmallow |
|---|---|---|
| Serialization | Yes | Yes |
| Deserialization | No (separate RequestParser) | Yes (unified) |
| Validation | Basic | Advanced |
| Nested Structures | Limited | Excellent |
| ORM Integration | Limited | Extensive |
| Error Messages | Basic | Detailed, customizable |
| Partial Updates | Difficult | Built-in support |
Getting Started with Marshmallow
Let's set up Marshmallow and create a basic schema:
# Install Marshmallow
pip install marshmallow
# For Flask integration
pip install flask-marshmallow
# For SQLAlchemy integration
pip install marshmallow-sqlalchemy
Now, let's create a simple schema:
from datetime import datetime
from marshmallow import Schema, fields, ValidationError
# Define a schema
class UserSchema(Schema):
    id = fields.Integer(dump_only=True)  # Read-only field
    username = fields.String(required=True)
    email = fields.Email(required=True)
    created_at = fields.DateTime(dump_only=True)
    bio = fields.String()
# Create an instance of the schema
user_schema = UserSchema()
users_schema = UserSchema(many=True)  # For serializing collections
# Sample data
user_data = {
    'id': 1,
    'username': 'johndoe',
    'email': 'john@example.com',
    'created_at': datetime(2023, 1, 15, 12, 30, 45),  # DateTime fields serialize datetime objects, not strings
    'bio': 'Software developer and tech enthusiast.'
}
# Serialize the data
result = user_schema.dump(user_data)
print(result)  # JSON-compatible dict
# Deserialize and validate data
try:
    user = user_schema.load({'username': 'janedoe', 'email': 'jane@example.com'})
    print(user)  # Valid data
except ValidationError as err:
    print(err.messages)  # Validation errors
This example demonstrates the basic usage of Marshmallow:
- Define a schema class that inherits from Schema
- Define fields with types and validation rules
- Create instances of the schema for single objects and collections
- Use dump() to serialize data (Python to JSON)
- Use load() to deserialize and validate data (JSON to Python)
Field Types and Validation
Marshmallow provides a rich set of field types and validation options:
Common Field Types
| Field Type | Python Type | Description |
|---|---|---|
| fields.String | str | String field |
| fields.Integer | int | Integer field |
| fields.Float | float | Floating point field |
| fields.Boolean | bool | Boolean field |
| fields.DateTime | datetime.datetime | Date and time field |
| fields.Date | datetime.date | Date field |
| fields.Time | datetime.time | Time field |
| fields.Decimal | decimal.Decimal | Precise decimal field |
| fields.Email | str | Email field with validation |
| fields.URL | str | URL field with validation |
| fields.Dict | dict | Dictionary field |
| fields.List | list | List field |
| fields.Nested | dict/object | Nested object field |
Field Options
Fields can be configured with various options:
# Field with options (marshmallow 3 option names)
username = fields.String(
    required=False,  # Set to True to require the field on load (cannot be combined with load_default)
    validate=Length(min=3, max=50, error='Username must be between 3 and 50 characters'),  # Validation constraint with its own message
    error_messages={  # Custom messages for the field's own checks
        'required': 'Username is required',
        'invalid': 'Username must be a string'
    },
    dump_only=False,  # If True, the field is only serialized (read-only)
    load_only=False,  # If True, the field is only deserialized (write-only)
    data_key='user_name',  # Different key name in the data
    attribute='name',  # Different attribute name in the object
    dump_default='guest',  # Value used when serializing if the attribute is missing (formerly default)
    load_default=None  # Value used when deserializing if the key is missing (formerly missing)
)
Common Validators
Marshmallow includes several built-in validators:
from marshmallow.validate import Length, Range, OneOf, Email, URL, Regexp
class UserSchema(Schema):
    username = fields.String(validate=Length(min=3, max=50))
    age = fields.Integer(validate=Range(min=18, max=120))
    role = fields.String(validate=OneOf(['user', 'admin', 'editor']))
    email = fields.String(validate=Email())
    website = fields.String(validate=URL())
    phone = fields.String(validate=Regexp(r'^\d{3}-\d{3}-\d{4}$'))
You can also create custom validators:
from marshmallow import ValidationError
def validate_not_in_blacklist(value):
    blacklist = ['admin', 'root', 'superuser']
    if value.lower() in blacklist:
        raise ValidationError('This username is reserved.')
class UserSchema(Schema):
    username = fields.String(validate=validate_not_in_blacklist)
Nested Schemas and Relationships
Marshmallow excels at handling complex, nested data structures, making it ideal for representing relationships between resources:
Nested Objects
class AddressSchema(Schema):
    street = fields.String(required=True)
    city = fields.String(required=True)
    state = fields.String(required=True)
    zip_code = fields.String(required=True)
    country = fields.String(required=True)
class UserSchema(Schema):
    id = fields.Integer(dump_only=True)
    username = fields.String(required=True)
    email = fields.Email(required=True)
    # Nested schema for the user's address
    address = fields.Nested(AddressSchema)
When serializing, this will produce nested JSON:
{
"id": 1,
"username": "johndoe",
"email": "john@example.com",
"address": {
"street": "123 Main St",
"city": "Anytown",
"state": "CA",
"zip_code": "12345",
"country": "USA"
}
}
Nested Collections
class CommentSchema(Schema):
    id = fields.Integer(dump_only=True)
    content = fields.String(required=True)
    created_at = fields.DateTime(dump_only=True)
class PostSchema(Schema):
    id = fields.Integer(dump_only=True)
    title = fields.String(required=True)
    content = fields.String(required=True)
    created_at = fields.DateTime(dump_only=True)
    # Nested collection of comments
    comments = fields.Nested(CommentSchema, many=True)
This will serialize a post with its comments:
{
"id": 123,
"title": "My First Post",
"content": "Hello, world!",
"created_at": "2023-01-15T12:30:45",
"comments": [
{
"id": 1,
"content": "Great post!",
"created_at": "2023-01-15T14:22:10"
},
{
"id": 2,
"content": "Thanks for sharing.",
"created_at": "2023-01-16T09:45:30"
}
]
}
Self-Referential Nested Schemas
For hierarchical data like categories or comments with replies, you can use self-referential schemas:
class CategorySchema(Schema):
    id = fields.Integer(dump_only=True)
    name = fields.String(required=True)
    description = fields.String()
    # Self-reference for parent-child relationships
    parent_id = fields.Integer(allow_none=True)
    # A lambda defers the reference until the class exists; the older 'self' string is deprecated in marshmallow 3
    children = fields.List(fields.Nested(lambda: CategorySchema(exclude=('parent_id',))))
Customizing Nested Fields
You can customize which fields are included in nested objects:
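# Include only specific fields
author = fields.Nested(UserSchema, only=('id', 'username'))
# Exclude specific fields
author = fields.Nested(UserSchema, exclude=('email', 'bio'))
# Nested collection with a fallback value
comments = fields.Nested(
    CommentSchema,
    many=True,
    # Value to emit when the object has no comments attribute
    dump_default=[]
)
# Filtered nested collection: Nested has no filtering option, so compute it in a Method field
approved_comments = fields.Method('get_approved_comments')
def get_approved_comments(self, obj):
    # Only include approved comments
    approved = [c for c in obj.comments if c.approved]
    return CommentSchema(many=True).dump(approved)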
Customizing Serialization and Deserialization
Marshmallow provides hooks for customizing how data is processed during serialization and deserialization:
Method Fields
Method fields let you compute values during serialization:
class UserSchema(Schema):
    id = fields.Integer(dump_only=True)
    username = fields.String(required=True)
    first_name = fields.String(required=True)
    last_name = fields.String(required=True)
    # Method field for computed value
    full_name = fields.Method('get_full_name')
    def get_full_name(self, obj):
        return f"{obj.first_name} {obj.last_name}"
Function Fields
Similar to method fields, but using a function:
def get_likes_count(obj):
    return len(obj.likes)
class PostSchema(Schema):
    id = fields.Integer(dump_only=True)
    title = fields.String(required=True)
    content = fields.String(required=True)
    # Function field
    likes_count = fields.Function(get_likes_count)
Pre-processing and Post-processing Hooks
Hooks allow you to modify data before or after processing:
from datetime import datetime
from marshmallow import Schema, fields, pre_load, post_load, pre_dump, post_dump
class UserSchema(Schema):
    id = fields.Integer(dump_only=True)
    username = fields.String(required=True)
    email = fields.Email(required=True)
    password = fields.String(load_only=True, required=True)
    # Pre-load hook: process data before validation
    @pre_load
    def lowercase_email(self, data, **kwargs):
        # Convert email to lowercase
        if 'email' in data and data['email']:
            data['email'] = data['email'].lower()
        return data
    # Post-load hook: process data after validation
    @post_load
    def make_user(self, data, **kwargs):
        # Convert dict to User object (User is your application's model class)
        return User(**data)
    # Pre-dump hook: process object before serialization
    @pre_dump
    def prepare_data(self, obj, **kwargs):
        # If obj is a dict, convert timestamp to datetime
        if isinstance(obj, dict) and 'created_at' in obj and isinstance(obj['created_at'], str):
            obj['created_at'] = datetime.fromisoformat(obj['created_at'])
        return obj
    # Post-dump hook: process output after serialization
    @post_dump
    def remove_none_values(self, data, **kwargs):
        # Remove fields with None values
        return {key: value for key, value in data.items() if value is not None}
Data Transformation Example
A common use case is transforming data during serialization or deserialization:
import json
from marshmallow import Schema, fields, pre_load, post_dump
class UserProfileSchema(Schema):
    id = fields.Integer(dump_only=True)
    username = fields.String(required=True)
    # Store preferences as JSON in DB, but expose as dict in API
    preferences = fields.Dict()
    @pre_load
    def parse_preferences(self, data, **kwargs):
        # If preferences is a JSON string, parse it
        if 'preferences' in data and isinstance(data['preferences'], str):
            try:
                data['preferences'] = json.loads(data['preferences'])
            except json.JSONDecodeError:
                data['preferences'] = {}
        return data
    @post_dump
    def format_data(self, data, **kwargs):
        # Add metadata
        data['_type'] = 'user_profile'
        data['_links'] = {
            'self': f"/api/users/{data['id']}",
            'settings': f"/api/users/{data['id']}/settings"
        }
        return data
Integrating with Flask-SQLAlchemy
The marshmallow-sqlalchemy extension simplifies working with SQLAlchemy models:
from datetime import datetime
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from flask_marshmallow import Marshmallow
app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///app.db'
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
# Initialize extensions
db = SQLAlchemy(app)
ma = Marshmallow(app)
# Define models
class User(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.String(80), unique=True, nullable=False)
    email = db.Column(db.String(120), unique=True, nullable=False)
    bio = db.Column(db.Text)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    posts = db.relationship('Post', backref='author', lazy=True)
class Post(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    title = db.Column(db.String(100), nullable=False)
    content = db.Column(db.Text, nullable=False)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False)
# Define schemas
class UserSchema(ma.SQLAlchemySchema):
    class Meta:
        model = User
        # Fields to include
        fields = ('id', 'username', 'email', 'bio', 'created_at', 'posts')
    # Auto-generated fields from the model
    id = ma.auto_field()
    username = ma.auto_field()
    email = ma.auto_field()
    bio = ma.auto_field()
    created_at = ma.auto_field()
    # Relationship field
    posts = ma.Nested('PostSchema', many=True, exclude=('author',))
class PostSchema(ma.SQLAlchemySchema):
    class Meta:
        model = Post
        # Include all fields
        fields = ('id', 'title', 'content', 'created_at', 'user_id', 'author')
    id = ma.auto_field()
    title = ma.auto_field()
    content = ma.auto_field()
    created_at = ma.auto_field()
    user_id = ma.auto_field()
    # Nested author field
    author = ma.Nested(UserSchema, only=('id', 'username'))
# Create schema instances
user_schema = UserSchema()
users_schema = UserSchema(many=True)
post_schema = PostSchema()
posts_schema = PostSchema(many=True)
This approach offers several advantages:
- Automatic field generation based on the model (auto_field())
- Relationship handling that respects lazy loading
- Consistent field naming between models and schemas
- Easier maintenance when model fields change
You can also use the even more automatic SQLAlchemyAutoSchema:
class UserSchema(ma.SQLAlchemyAutoSchema):
    class Meta:
        model = User
        include_fk = True  # Include foreign keys
        include_relationships = True  # Include relationships
        load_instance = True  # Deserialize to model instances
class PostSchema(ma.SQLAlchemyAutoSchema):
    class Meta:
        model = Post
        include_fk = True
        include_relationships = True
        load_instance = True
    # Override the author field for custom behavior
    author = ma.Nested(UserSchema, only=('id', 'username'))
With load_instance=True, the schema's load() method will create model instances directly, making it easy to deserialize API data to database objects.
Using Marshmallow with Flask-RESTful
Marshmallow can be used alongside Flask-RESTful to provide enhanced serialization and validation:
from flask import Flask, request
from flask_restful import Api, Resource
from flask_sqlalchemy import SQLAlchemy
from flask_marshmallow import Marshmallow
from marshmallow import ValidationError
app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///app.db'
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
# Initialize extensions
db = SQLAlchemy(app)
ma = Marshmallow(app)
api = Api(app)
# Define models
class User(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.String(80), unique=True, nullable=False)
    email = db.Column(db.String(120), unique=True, nullable=False)
# Define schemas
class UserSchema(ma.SQLAlchemyAutoSchema):
    class Meta:
        model = User
        load_instance = True
# Create schema instances
user_schema = UserSchema()
users_schema = UserSchema(many=True)
# Define resources
class UserResource(Resource):
    def get(self, user_id):
        user = User.query.get_or_404(user_id)
        return user_schema.dump(user)
    def put(self, user_id):
        user = User.query.get_or_404(user_id)
        try:
            # Deserialize and validate input data
            updated_user = user_schema.load(request.json, instance=user, partial=True)
            db.session.commit()
            return user_schema.dump(updated_user)
        except ValidationError as err:
            return {'message': 'Validation error', 'errors': err.messages}, 400
    def delete(self, user_id):
        user = User.query.get_or_404(user_id)
        db.session.delete(user)
        db.session.commit()
        return '', 204
class UserListResource(Resource):
    def get(self):
        users = User.query.all()
        return users_schema.dump(users)
    def post(self):
        try:
            # Create a new user instance from the data
            user = user_schema.load(request.json)
            db.session.add(user)
            db.session.commit()
            return user_schema.dump(user), 201
        except ValidationError as err:
            return {'message': 'Validation error', 'errors': err.messages}, 400
# Register resources
api.add_resource(UserResource, '/users/<int:user_id>')
api.add_resource(UserListResource, '/users')
The key differences from using Flask-RESTful's marshaling:
- Use schema.dump() for serialization instead of marshal_with
- Use schema.load() for deserialization and validation instead of reqparse
- Handle ValidationError exceptions for input validation errors
- Use partial=True for partial updates (PUT/PATCH)
- Use instance=obj to update existing instances
Advanced Features
Field Masking
Allow clients to request only the fields they need:
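@app.route('/users/<int:user_id>')
def get_user(user_id):
    user = User.query.get_or_404(user_id)
    # Get requested fields from query parameter, e.g. ?fields=id,username
    # (named requested_fields so it does not shadow marshmallow's fields module)
    requested_fields = request.args.get('fields')
    if requested_fields:
        # Create a schema with only the requested fields;
        # marshmallow raises ValueError if a name is not a field of the schema
        only = requested_fields.split(',')
        schema = UserSchema(only=only)
    else:
        schema = user_schema
    return schema.dump(user)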
Schema Inheritance
Create hierarchies of schemas to reuse field definitions:
class BasePersonSchema(Schema):
    id = fields.Integer(dump_only=True)
    first_name = fields.String(required=True)
    last_name = fields.String(required=True)
    email = fields.Email(required=True)
class UserSchema(BasePersonSchema):
    username = fields.String(required=True)
    password = fields.String(load_only=True, required=True)
    role = fields.String(validate=OneOf(['user', 'admin']))
class EmployeeSchema(BasePersonSchema):
    department = fields.String(required=True)
    hire_date = fields.Date(required=True)
    salary = fields.Decimal(places=2, load_only=True)
Context-dependent Serialization
Pass context to schemas to customize behavior:
class PostSchema(ma.SQLAlchemyAutoSchema):
    class Meta:
        model = Post
    # Field that changes based on context
    content = fields.Method('get_content')
    def get_content(self, obj):
        # Get user role from context
        user_role = self.context.get('user_role', 'guest')
        # Show full content to authorized users and to the post's author
        if user_role in ['admin', 'editor'] or obj.user_id == self.context.get('user_id'):
            return obj.content
        # Show preview to others
        preview_length = self.context.get('preview_length', 100)
        if len(obj.content) > preview_length:
            return obj.content[:preview_length] + '...'
        return obj.content
# Using context: in marshmallow 3 it is passed to the schema constructor; dump() has no context argument
admin_post = PostSchema(context={'user_role': 'admin'}).dump(post)
user_post = PostSchema(context={'user_role': 'user', 'user_id': current_user.id}).dump(post)
guest_post = PostSchema(context={'user_role': 'guest', 'preview_length': 50}).dump(post)
Schema Registries
For applications with many schemas, a registry can help manage them:
# Create a schema registry: name -> (schema class, constructor options)
schemas = {
    'user': (UserSchema, {}),
    'users': (UserSchema, {'many': True}),
    'post': (PostSchema, {}),
    'posts': (PostSchema, {'many': True}),
    'comment': (CommentSchema, {}),
    'comments': (CommentSchema, {'many': True}),
    'user_profile': (UserProfileSchema, {}),
    'category': (CategorySchema, {}),
    'categories': (CategorySchema, {'many': True})
}
def get_schema(name, **context):
    """Get a new schema instance by name with optional context."""
    if name not in schemas:
        raise ValueError(f"Unknown schema: {name}")
    schema_class, options = schemas[name]
    # Build a fresh instance so context is never shared between requests
    return schema_class(context=context, **options)
# Using the registry
user_data = get_schema('user', user_role='admin').dump(user)
posts_data = get_schema('posts', preview_length=200).dump(posts)
Error Handling and Validation
Proper error handling is crucial for API usability. Marshmallow provides detailed validation errors:
from marshmallow import ValidationError
from flask import jsonify, request
# Global error handler for validation errors
@app.errorhandler(ValidationError)
def handle_validation_error(err):
    return jsonify({
        'error': 'Validation error',
        'messages': err.messages
    }), 400
@app.route('/users', methods=['POST'])
def create_user():
    # Deserialize and validate the request data.
    # A ValidationError raised by load() propagates to the global error handler above,
    # so no try/except is needed here.
    user = user_schema.load(request.json)
    db.session.add(user)
    db.session.commit()
    return user_schema.dump(user), 201
A validation error response might look like:
{
"error": "Validation error",
"messages": {
"email": ["Not a valid email address."],
"username": ["Missing data for required field."],
"age": ["Must be greater than or equal to 18."]
}
}
Custom Error Messages
You can customize error messages for fields and validators:
class UserSchema(Schema):
    username = fields.String(
        required=True,
        # Built-in validators take their custom message through the error argument
        validate=Length(min=3, max=50, error='Username must be between 3 and 50 characters'),
        error_messages={
            'required': 'Please provide a username',
            'invalid': 'Not a valid string'
        }
    )
    email = fields.Email(
        required=True,
        error_messages={
            'required': 'Email address is required',
            'invalid': 'Please provide a valid email address'
        }
    )
Field-level Validation
Add validation methods for specific fields:
from marshmallow import Schema, fields, ValidationError, validates, validates_schema
class UserSchema(Schema):
    username = fields.String(required=True)
    email = fields.Email(required=True)
    password = fields.String(required=True, load_only=True)
    password_confirm = fields.String(required=True, load_only=True)
    # Field-level validator: the @validates decorator registers the method for a field
    @validates('username')
    def validate_username(self, value, **kwargs):
        # Check if username is already taken
        existing = User.query.filter_by(username=value).first()
        if existing:
            raise ValidationError('This username is already taken')
    # Validate password strength
    @validates('password')
    def validate_password(self, value, **kwargs):
        if len(value) < 8:
            raise ValidationError('Password must be at least 8 characters')
        if not any(c.isupper() for c in value):
            raise ValidationError('Password must contain at least one uppercase letter')
        if not any(c.isdigit() for c in value):
            raise ValidationError('Password must contain at least one number')
    # Schema-level validation
    @validates_schema
    def validate_passwords_match(self, data, **kwargs):
        if data.get('password') != data.get('password_confirm'):
            raise ValidationError('Passwords must match', 'password_confirm')
Schema-level Validation
Validate multiple fields together:
class EventSchema(Schema):
    title = fields.String(required=True)
    start_date = fields.Date(required=True)
    end_date = fields.Date(required=True)
    @validates_schema
    def validate_dates(self, data, **kwargs):
        if 'start_date' in data and 'end_date' in data:
            if data['end_date'] < data['start_date']:
                raise ValidationError('End date must be after start date', 'end_date')
Complete Example: Blog API with Marshmallow
Let's put it all together with a complete example of a blog API using Flask, SQLAlchemy, and Marshmallow:
from flask import Flask, request, jsonify
from flask_sqlalchemy import SQLAlchemy
from flask_marshmallow import Marshmallow
from marshmallow import ValidationError, validates, validates_schema
from datetime import datetime
# Initialize app
app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///blog.db'
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
# Initialize extensions
db = SQLAlchemy(app)
ma = Marshmallow(app)
# Define models
class User(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.String(50), unique=True, nullable=False)
    email = db.Column(db.String(100), unique=True, nullable=False)
    password = db.Column(db.String(100), nullable=False)  # Would be hashed in real app
    bio = db.Column(db.Text)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    posts = db.relationship('Post', backref='author', lazy=True, cascade='all, delete-orphan')
    comments = db.relationship('Comment', backref='author', lazy=True, cascade='all, delete-orphan')
class Category(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(50), unique=True, nullable=False)
    description = db.Column(db.Text)
    posts = db.relationship('Post', backref='category', lazy=True)
class Post(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    title = db.Column(db.String(100), nullable=False)
    content = db.Column(db.Text, nullable=False)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
    published = db.Column(db.Boolean, default=False)
    user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False)
    category_id = db.Column(db.Integer, db.ForeignKey('category.id'), nullable=True)
    comments = db.relationship('Comment', backref='post', lazy=True, cascade='all, delete-orphan')
class Comment(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    content = db.Column(db.Text, nullable=False)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False)
    post_id = db.Column(db.Integer, db.ForeignKey('post.id'), nullable=False)
# Create the database tables
with app.app_context():
    db.create_all()
# Define schemas
class UserSchema(ma.SQLAlchemyAutoSchema):
    class Meta:
        model = User
        load_instance = True  # Deserialize to model instances
    # Accept password on input but never include it in serialized output
    password = ma.auto_field(load_only=True)
    # Add validation
    @validates('username')
    def validate_username(self, value, **kwargs):
        if len(value) < 3:
            raise ValidationError('Username must be at least 3 characters')
class CategorySchema(ma.SQLAlchemyAutoSchema):
    class Meta:
        model = Category
        load_instance = True
class CommentSchema(ma.SQLAlchemyAutoSchema):
    class Meta:
        model = Comment
        include_fk = True
        load_instance = True
    # Add author information
    author = ma.Nested(UserSchema, only=('id', 'username'), dump_only=True)
class PostSchema(ma.SQLAlchemyAutoSchema):
    class Meta:
        model = Post
        include_fk = True
        load_instance = True
    # Add nested fields
    author = ma.Nested(UserSchema, only=('id', 'username'))
    category = ma.Nested(CategorySchema, only=('id', 'name'))
    # post_id is redundant inside a post, so leave it out of nested comments
    comments = ma.Nested(CommentSchema, many=True, exclude=('post_id',))
    # Add computed fields
    comment_count = ma.Function(lambda obj: len(obj.comments))
    # Validation
    @validates('title')
    def validate_title(self, value, **kwargs):
        if len(value) < 5:
            raise ValidationError('Title must be at least 5 characters')
# Schema instances
user_schema = UserSchema()
users_schema = UserSchema(many=True)
category_schema = CategorySchema()
categories_schema = CategorySchema(many=True)
post_schema = PostSchema()
posts_schema = PostSchema(many=True)
comment_schema = CommentSchema()
comments_schema = CommentSchema(many=True)
# Error handler
@app.errorhandler(ValidationError)
def handle_validation_error(err):
    return jsonify({
        'error': 'Validation error',
        'messages': err.messages
    }), 400
# User endpoints
@app.route('/users', methods=['GET'])
def get_users():
    users = User.query.all()
    return jsonify(users_schema.dump(users))
@app.route('/users/<int:user_id>', methods=['GET'])
def get_user(user_id):
    user = User.query.get_or_404(user_id)
    return jsonify(user_schema.dump(user))
@app.route('/users', methods=['POST'])
def create_user():
    try:
        user = user_schema.load(request.json)
        db.session.add(user)
        db.session.commit()
        return jsonify(user_schema.dump(user)), 201
    except ValidationError as err:
        return jsonify({
            'error': 'Validation error',
            'messages': err.messages
        }), 400
# Category endpoints
@app.route('/categories', methods=['GET'])
def get_categories():
    categories = Category.query.all()
    return jsonify(categories_schema.dump(categories))
@app.route('/categories/<int:category_id>', methods=['GET'])
def get_category(category_id):
    category = Category.query.get_or_404(category_id)
    return jsonify(category_schema.dump(category))
@app.route('/categories', methods=['POST'])
def create_category():
    try:
        category = category_schema.load(request.json)
        db.session.add(category)
        db.session.commit()
        return jsonify(category_schema.dump(category)), 201
    except ValidationError as err:
        raise err
# Post endpoints
@app.route('/posts', methods=['GET'])
def get_posts():
    # Add filtering
    category_id = request.args.get('category_id', type=int)
    user_id = request.args.get('user_id', type=int)
    published = request.args.get('published')
    # Start with base query
    query = Post.query
    # Apply filters
    if category_id:
        query = query.filter_by(category_id=category_id)
    if user_id:
        query = query.filter_by(user_id=user_id)
    if published is not None:
        is_published = published.lower() == 'true'
        query = query.filter_by(published=is_published)
    # Execute query and serialize
    posts = query.order_by(Post.created_at.desc()).all()
    result = posts_schema.dump(posts)
    return jsonify(result)
@app.route('/posts/<int:post_id>', methods=['GET'])
def get_post(post_id):
    post = Post.query.get_or_404(post_id)
    return jsonify(post_schema.dump(post))
@app.route('/posts', methods=['POST'])
def create_post():
    try:
        post = post_schema.load(request.json)
        db.session.add(post)
        db.session.commit()
        return jsonify(post_schema.dump(post)), 201
    except ValidationError as err:
        raise err
@app.route('/posts/<int:post_id>', methods=['PUT'])
def update_post(post_id):
    post = Post.query.get_or_404(post_id)
    try:
        # Update existing instance with partial data
        post = post_schema.load(request.json, instance=post, partial=True)
        db.session.commit()
        return jsonify(post_schema.dump(post))
    except ValidationError as err:
        raise err
@app.route('/posts/<int:post_id>', methods=['DELETE'])
def delete_post(post_id):
    post = Post.query.get_or_404(post_id)
    db.session.delete(post)
    db.session.commit()
    return '', 204
# Comment endpoints
@app.route('/posts/<int:post_id>/comments', methods=['GET'])
def get_post_comments(post_id):
    post = Post.query.get_or_404(post_id)
    return jsonify(comments_schema.dump(post.comments))
@app.route('/posts/<int:post_id>/comments', methods=['POST'])
def add_comment(post_id):
    post = Post.query.get_or_404(post_id)
    try:
        # Create comment linked to the post
        data = request.json
        data['post_id'] = post_id
        comment = comment_schema.load(data)
        db.session.add(comment)
        db.session.commit()
        return jsonify(comment_schema.dump(comment)), 201
    except ValidationError as err:
        raise err
if __name__ == '__main__':
    app.run(debug=True)
This example demonstrates a complete blog API with:
- SQLAlchemy models with relationships
- Marshmallow schemas for serialization, deserialization, and validation
- Nested fields for related resources
- Custom validation rules
- Computed fields
- Filtering capabilities
- Complete CRUD operations for all resources
Practical Activity: Task Manager API with Marshmallow
Let's apply what we've learned by enhancing the task manager API from the previous lecture with Marshmallow:
Here's the outline:
- Set up Flask, SQLAlchemy, and Marshmallow
- Create Task, Category, and User models
- Define schemas for each model with appropriate validation
- Implement API endpoints with CRUD operations
- Add filtering, sorting, and pagination
- Implement proper error handling
Start by creating the following files:
task_manager/
├── app.py # Main application file
├── models.py # Database models
├── schemas.py # Marshmallow schemas
├── api.py # API routes
└── requirements.txt # Dependencies
Implementation steps:
- Define Task, Category, and User models with appropriate relationships
- Create schemas for each model with validation rules
- Implement API endpoints for all models
- Add validation and error handling
- Implement filtering, sorting, and pagination
- Test the API endpoints
This activity will help you practice using Marshmallow for serialization, deserialization, and validation in a real-world API.
Key Takeaways
- Marshmallow provides powerful serialization, deserialization, and validation capabilities for Flask APIs
- Schemas define how to convert between Python objects and serialized representations
- Field types and validation rules ensure data quality and consistency
- Nested schemas make it easy to handle complex, related data structures
- Pre-processing and post-processing hooks allow for custom data transformations
- Integration with SQLAlchemy simplifies working with database models
- Marshmallow provides detailed validation errors for better API usability
- Advanced features like partial updates, schema inheritance, and context-dependent serialization make Marshmallow suitable for complex applications