# Source: trainvel/trainvel/gtfs/management/commands/update_trainvel_gtfs.py (410 lines, 19 KiB, Python)

import csv
import os.path
import tempfile
from datetime import datetime, timedelta
from time import time
from zipfile import ZipFile
from zoneinfo import ZoneInfo
import requests
from django.core.management import BaseCommand
from django.db import transaction
from tqdm import tqdm
from trainvel.gtfs.models import Agency, Calendar, CalendarDate, FeedInfo, GTFSFeed, Route, Stop, StopTime, \
Transfer, Trip, PickupType, TripUpdate
class Command(BaseCommand):
    """Download every registered GTFS feed and load it into the database.

    For each ``GTFSFeed`` row the command checks (unless ``--force`` is given)
    whether the remote archive changed since the last import, downloads and
    unpacks it, then loads the GTFS text files into the matching models.
    """

    help = "Update the Trainvel GTFS database."

    # Seconds granted to the feed server before an HTTP request is aborted;
    # without a timeout a stalled server would block the command forever.
    http_timeout = 60

    # Number of bytes fetched per iteration while streaming the archive.
    download_chunk_size = 64 * 1024

    def add_arguments(self, parser):
        """Register the command line options of the command."""
        parser.add_argument('--debug', '-d', action='store_true', help="Activate debug mode")
        parser.add_argument('--bulk_size', '-b', type=int, default=2000, help="Number of objects to create in bulk.")
        parser.add_argument('--dry-run', action='store_true',
                            help="Do not update the database, only print what would be done.")
        parser.add_argument('--force', '-f', action='store_true', help="Force the update of the database.")

    def handle(self, debug: bool = False, bulk_size: int = 100, dry_run: bool = False, force: bool = False,
               verbosity: int = 1, *args, **options):
        """Refresh every GTFS feed registered in the database.

        :param debug: Accepted for command line compatibility; currently unused.
        :param bulk_size: Number of objects accumulated before each bulk insert.
        :param dry_run: When true, feeds are downloaded and parsed but nothing
            is written to (or deleted from) the database.
        :param force: When true, skip the ETag / Last-Modified freshness check.
        :param verbosity: Standard Django verbosity level.
        :raises requests.HTTPError: If the archive download returns an error status.
        """
        if dry_run:
            self.stdout.write(self.style.WARNING("Dry run mode activated."))

        self.stdout.write("Updating database...")

        for gtfs_feed in GTFSFeed.objects.all():
            if not force and self._is_up_to_date(gtfs_feed):
                if verbosity >= 1:
                    self.stdout.write(self.style.WARNING(f"Database is already up-to-date for {gtfs_feed}."))
                continue

            self.stdout.write(f"Downloading GTFS feed for {gtfs_feed}...")
            with tempfile.TemporaryFile(suffix=".zip") as file:
                # The response is closed as soon as the archive is on disk so
                # that the connection is not held open during the import.
                with requests.get(gtfs_feed.feed_url, allow_redirects=True, stream=True,
                                  timeout=self.http_timeout) as resp:
                    # Fail early on an HTTP error instead of trying to unzip an error page.
                    resp.raise_for_status()
                    headers = resp.headers
                    for chunk in resp.iter_content(chunk_size=self.download_chunk_size):
                        file.write(chunk)
                file.seek(0)

                with tempfile.TemporaryDirectory() as tmp_dir:
                    with ZipFile(file) as zipfile:
                        zipfile.extractall(tmp_dir)

                    self.parse_gtfs(tmp_dir, gtfs_feed, bulk_size, dry_run, verbosity)

            if dry_run:
                # Recording the new ETag / Last-Modified here would make the
                # next real run believe that the feed was already imported.
                continue

            changed = False
            if 'ETag' in headers:
                gtfs_feed.etag = headers['ETag']
                changed = True
            if 'Last-Modified' in headers:
                last_modified = self._parse_http_date(headers['Last-Modified'])
                if last_modified is not None:
                    gtfs_feed.last_modified = last_modified
                    changed = True
            if changed:
                gtfs_feed.save()

    @staticmethod
    def _parse_http_date(value: str):
        """Parse an HTTP date header into an aware datetime, or return None if it is malformed."""
        try:
            return datetime.strptime(value, "%a, %d %b %Y %H:%M:%S %Z") \
                .replace(tzinfo=ZoneInfo(value.split(' ')[-1]))
        except (ValueError, KeyError):
            # ValueError: unexpected date layout; KeyError: unknown time zone name.
            return None

    @staticmethod
    def _parse_gtfs_time(value: str) -> timedelta:
        """Convert a GTFS ``H:MM:SS`` time (hours may exceed 23) into a timedelta."""
        hours, minutes, seconds = map(int, value.split(':'))
        return timedelta(seconds=hours * 3600 + minutes * 60 + seconds)

    def _is_up_to_date(self, gtfs_feed) -> bool:
        """Tell whether the remote archive is unchanged since the last import.

        The stored ETag is compared first, then the stored Last-Modified date.
        A header that is missing or cannot be parsed never counts as a match.
        """
        resp = requests.head(gtfs_feed.feed_url, allow_redirects=True, timeout=self.http_timeout)

        if 'ETag' in resp.headers and gtfs_feed.etag and resp.headers['ETag'] == gtfs_feed.etag:
            return True

        if 'Last-Modified' in resp.headers and gtfs_feed.last_modified:
            last_modified = self._parse_http_date(resp.headers['Last-Modified'])
            if last_modified is not None and last_modified <= gtfs_feed.last_modified:
                return True

        return False

    def parse_gtfs(self, zip_dir: str, gtfs_feed: GTFSFeed, bulk_size: int, dry_run: bool, verbosity: int):
        """Load the GTFS files extracted in ``zip_dir`` into the database.

        Agencies, stops, routes and transfers are upserted. Trip updates, stop
        times, trips and calendars of the feed are deleted and re-created.
        Every identifier is prefixed with the feed code.

        NOTE: the import as a whole is not wrapped in a transaction, so a
        failure midway leaves the feed partially imported.

        :param zip_dir: Directory containing the extracted ``*.txt`` files.
        :param gtfs_feed: Feed the files belong to.
        :param bulk_size: Number of objects accumulated before each bulk insert.
        :param dry_run: When true, files are parsed but the database is untouched.
        :param verbosity: A progress bar is displayed when it is at least 2.
        """
        gtfs_code = gtfs_feed.code

        def read_csv(filename):
            """Yield the rows of a GTFS file as dicts with stripped keys and values."""
            # GTFS files are UTF-8, possibly with a BOM; newline='' is what the csv module expects.
            with open(os.path.join(zip_dir, filename), 'r', encoding='utf-8-sig', newline='') as f:
                reader = csv.DictReader(f)
                # fieldnames is None when the file is empty.
                reader.fieldnames = [field.replace('\ufeff', '').strip()
                                     for field in reader.fieldnames or []]
                iterator = tqdm(reader, desc=filename, unit=' rows') if verbosity >= 2 else reader
                for row in iterator:
                    # Short rows are padded with None values and extra cells
                    # are stored under the None key: neither can be stripped.
                    yield {k.strip(): (v or "").strip() for k, v in row.items() if k is not None}

        def optional(row, key, default):
            """Return ``row[key]``, or ``default`` when the column is missing or the cell is empty."""
            value = row.get(key)
            return value if value else default

        def flush(model, objs, **kwargs):
            """Bulk-insert the pending objects (unless in dry run) and empty the list."""
            if objs and not dry_run:
                model.objects.bulk_create(objs, **kwargs)
            objs.clear()

        # --- Agencies ---
        agencies = []
        for agency_dict in read_csv("agency.txt"):
            agencies.append(Agency(
                # agency_id may be omitted when the feed has a single agency.
                id=f"{gtfs_code}-{agency_dict.get('agency_id', '')}",
                name=agency_dict['agency_name'],
                url=agency_dict['agency_url'],
                timezone=agency_dict['agency_timezone'],
                lang=optional(agency_dict, 'agency_lang', "fr"),
                phone=agency_dict.get('agency_phone', ""),
                email=agency_dict.get('agency_email', ""),
                gtfs_feed_id=gtfs_code,
            ))
        # Remembered for the routes that do not name their agency.
        sole_agency_id = agencies[0].id if len(agencies) == 1 else None
        flush(Agency, agencies,
              update_conflicts=True,
              update_fields=['name', 'url', 'timezone', 'lang', 'phone', 'email', 'gtfs_feed'],
              unique_fields=['id'])

        # --- Stops ---
        stops = []
        for stop_dict in read_csv("stops.txt"):
            parent_station_id = stop_dict.get('parent_station', None)
            parent_station_id = f"{gtfs_code}-{parent_station_id}" if parent_station_id else None
            stops.append(Stop(
                id=f"{gtfs_code}-{stop_dict['stop_id']}",
                name=stop_dict['stop_name'],
                desc=stop_dict.get('stop_desc', ""),
                lat=stop_dict['stop_lat'],
                lon=stop_dict['stop_lon'],
                zone_id=stop_dict.get('zone_id', ""),
                url=stop_dict.get('stop_url', ""),
                location_type=optional(stop_dict, 'location_type', 0),
                parent_station_id=parent_station_id,
                timezone=stop_dict.get('stop_timezone', ""),
                wheelchair_boarding=optional(stop_dict, 'wheelchair_boarding', 0),
                level_id=stop_dict.get('level_id', ""),
                platform_code=stop_dict.get('platform_code', ""),
                gtfs_feed_id=gtfs_code,
            ))
        flush(Stop, stops,
              batch_size=bulk_size,
              update_conflicts=True,
              update_fields=['name', 'desc', 'lat', 'lon', 'zone_id', 'url',
                             'location_type', 'parent_station_id', 'timezone',
                             'wheelchair_boarding', 'level_id', 'platform_code',
                             'gtfs_feed'],
              unique_fields=['id'])

        # --- Routes ---
        route_upsert = dict(
            update_conflicts=True,
            update_fields=['agency_id', 'short_name', 'long_name', 'desc',
                           'type', 'url', 'color', 'text_color',
                           'gtfs_feed'],
            unique_fields=['id'],
        )
        routes = []
        for route_dict in read_csv("routes.txt"):
            raw_agency_id = route_dict.get('agency_id', "")
            if raw_agency_id:
                agency_id = f"{gtfs_code}-{raw_agency_id}"
            else:
                # The agency is optional when the feed has only one. Agency
                # ids already carry the feed prefix, so use the id as is.
                if sole_agency_id is None:
                    sole_agency_id = Agency.objects.get(gtfs_feed_id=gtfs_code).id
                agency_id = sole_agency_id

            routes.append(Route(
                id=f"{gtfs_code}-{route_dict['route_id']}",
                agency_id=agency_id,
                short_name=route_dict['route_short_name'],
                long_name=route_dict['route_long_name'],
                desc=route_dict.get('route_desc', ""),
                type=route_dict['route_type'],
                url=route_dict.get('route_url', ""),
                color=route_dict.get('route_color', ""),
                text_color=route_dict.get('route_text_color', ""),
                gtfs_feed_id=gtfs_code,
            ))
            if len(routes) >= bulk_size:
                flush(Route, routes, **route_upsert)
        flush(Route, routes, **route_upsert)

        # --- Removal of the data that is fully re-created below ---
        # Skipped in dry run: a dry run must leave the database untouched.
        if not dry_run:
            if verbosity >= 1:
                self.stdout.write("Deleting old calendars, trips and stop times…")
            start_time = time()

            TripUpdate.objects.filter(trip__gtfs_feed_id=gtfs_code).delete()
            StopTime.objects.filter(trip__gtfs_feed_id=gtfs_code)._raw_delete(StopTime.objects.db)
            Trip.objects.filter(gtfs_feed_id=gtfs_code)._raw_delete(Trip.objects.db)
            Calendar.objects.filter(gtfs_feed_id=gtfs_code).delete()

            if verbosity >= 1:
                self.stdout.write(f"Done in {time() - start_time:.2f} s")

        # --- Calendars (calendar.txt is optional) ---
        if os.path.exists(os.path.join(zip_dir, "calendar.txt")):
            # Keyed by id so that a service listed twice is only inserted once per batch.
            calendars = {}
            for calendar_dict in read_csv("calendar.txt"):
                calendar = Calendar(
                    id=f"{gtfs_code}-{calendar_dict['service_id']}",
                    monday=calendar_dict['monday'],
                    tuesday=calendar_dict['tuesday'],
                    wednesday=calendar_dict['wednesday'],
                    thursday=calendar_dict['thursday'],
                    friday=calendar_dict['friday'],
                    saturday=calendar_dict['saturday'],
                    sunday=calendar_dict['sunday'],
                    start_date=calendar_dict['start_date'],
                    end_date=calendar_dict['end_date'],
                    gtfs_feed_id=gtfs_code,
                )
                calendars[calendar.id] = calendar

                if len(calendars) >= bulk_size:
                    if not dry_run:
                        Calendar.objects.bulk_create(list(calendars.values()), batch_size=bulk_size)
                    calendars.clear()
            if calendars and not dry_run:
                Calendar.objects.bulk_create(list(calendars.values()), batch_size=bulk_size)
            calendars.clear()

        # --- Calendar exceptions (calendar_dates.txt is optional) ---
        if os.path.exists(os.path.join(zip_dir, "calendar_dates.txt")):
            calendar_dates = []
            all_calendars = {calendar.id: calendar for calendar in Calendar.objects.filter(gtfs_feed_id=gtfs_code)}
            new_calendars = {}  # services that only exist through their exceptions
            widened_calendars = {}  # stored services whose date range had to grow

            # Exceptions may be inserted before the calendar they refer to,
            # hence the single transaction around the whole file.
            with transaction.atomic():
                for calendar_date_dict in read_csv("calendar_dates.txt"):
                    service_id = f"{gtfs_code}-{calendar_date_dict['service_id']}"
                    date = datetime.fromisoformat(calendar_date_dict['date']).date()

                    calendar_dates.append(CalendarDate(
                        id=f"{gtfs_code}-{calendar_date_dict['service_id']}-{calendar_date_dict['date']}",
                        service_id=service_id,
                        date=calendar_date_dict['date'],
                        exception_type=calendar_date_dict['exception_type'],
                    ))

                    calendar = all_calendars.get(service_id)
                    if calendar is None:
                        calendar = Calendar(
                            id=service_id,
                            monday=False,
                            tuesday=False,
                            wednesday=False,
                            thursday=False,
                            friday=False,
                            saturday=False,
                            sunday=False,
                            start_date=date,
                            end_date=date,
                            gtfs_feed_id=gtfs_code,
                        )
                        all_calendars[service_id] = calendar
                        new_calendars[service_id] = calendar
                    elif calendar.start_date > date or calendar.end_date < date:
                        calendar.start_date = min(calendar.start_date, date)
                        calendar.end_date = max(calendar.end_date, date)
                        if service_id not in new_calendars:
                            # Already stored: the wider range has to be
                            # written back explicitly or it would be lost.
                            widened_calendars[service_id] = calendar

                    if len(calendar_dates) >= bulk_size:
                        flush(CalendarDate, calendar_dates, batch_size=bulk_size)

                if not dry_run:
                    Calendar.objects.bulk_create(list(new_calendars.values()), batch_size=bulk_size)
                    Calendar.objects.bulk_update(list(widened_calendars.values()),
                                                 ['start_date', 'end_date'], batch_size=bulk_size)
                flush(CalendarDate, calendar_dates, batch_size=bulk_size)

            new_calendars.clear()
            widened_calendars.clear()

        # --- Trips ---
        trips = []
        for trip_dict in read_csv("trips.txt"):
            trips.append(Trip(
                id=f"{gtfs_code}-{trip_dict['trip_id']}",
                route_id=f"{gtfs_code}-{trip_dict['route_id']}",
                service_id=f"{gtfs_code}-{trip_dict['service_id']}",
                headsign=trip_dict.get('trip_headsign', ""),
                short_name=trip_dict.get('trip_short_name', ""),
                direction_id=optional(trip_dict, 'direction_id', None),
                block_id=trip_dict.get('block_id', ""),
                shape_id=trip_dict.get('shape_id', ""),
                wheelchair_accessible=optional(trip_dict, 'wheelchair_accessible', None),
                bikes_allowed=optional(trip_dict, 'bikes_allowed', None),
                gtfs_feed_id=gtfs_code,
            ))
            if len(trips) >= bulk_size:
                flush(Trip, trips)
        flush(Trip, trips)

        # --- Stop times ---
        stop_times = []
        for stop_time_dict in read_csv("stop_times.txt"):
            raw_trip_id = stop_time_dict['trip_id']
            stop_sequence = stop_time_dict['stop_sequence']
            # When only one of the two times is filled in, both are the same instant.
            arrival = stop_time_dict['arrival_time'] or stop_time_dict['departure_time']
            departure = stop_time_dict['departure_time'] or stop_time_dict['arrival_time']

            stop_times.append(StopTime(
                id=f"{gtfs_code}-{raw_trip_id}-{stop_sequence}",
                trip_id=f"{gtfs_code}-{raw_trip_id}",
                arrival_time=self._parse_gtfs_time(arrival),
                departure_time=self._parse_gtfs_time(departure),
                stop_id=f"{gtfs_code}-{stop_time_dict['stop_id']}",
                stop_sequence=stop_sequence,
                stop_headsign=stop_time_dict.get('stop_headsign', ""),
                pickup_type=optional(stop_time_dict, 'pickup_type', PickupType.REGULAR),
                drop_off_type=optional(stop_time_dict, 'drop_off_type', PickupType.REGULAR),
                timepoint=optional(stop_time_dict, 'timepoint', None),
            ))
            if len(stop_times) >= bulk_size:
                flush(StopTime, stop_times)
        flush(StopTime, stop_times)

        # --- Transfers (optional file) ---
        if os.path.exists(os.path.join(zip_dir, "transfers.txt")):
            # Transfers are not deleted beforehand, so they must be upserted:
            # a plain insert would hit duplicate keys on every re-import.
            transfer_upsert = dict(
                update_conflicts=True,
                update_fields=['from_stop_id', 'to_stop_id', 'transfer_type', 'min_transfer_time'],
                unique_fields=['id'],
            )
            transfers = []
            for transfer_dict in read_csv("transfers.txt"):
                from_stop_id = transfer_dict['from_stop_id']
                to_stop_id = transfer_dict['to_stop_id']
                transfers.append(Transfer(
                    id=f"{gtfs_code}-{from_stop_id}-{to_stop_id}",
                    from_stop_id=f"{gtfs_code}-{from_stop_id}",
                    to_stop_id=f"{gtfs_code}-{to_stop_id}",
                    transfer_type=transfer_dict['transfer_type'],
                    min_transfer_time=optional(transfer_dict, 'min_transfer_time', 0),
                ))
                if len(transfers) >= bulk_size:
                    flush(Transfer, transfers, **transfer_upsert)
            flush(Transfer, transfers, **transfer_upsert)

        # --- Feed information (optional file) ---
        if os.path.exists(os.path.join(zip_dir, "feed_info.txt")) and not dry_run:
            for feed_info_dict in read_csv("feed_info.txt"):
                FeedInfo.objects.update_or_create(
                    publisher_name=feed_info_dict['feed_publisher_name'],
                    gtfs_feed_id=gtfs_code,
                    defaults=dict(
                        publisher_url=feed_info_dict['feed_publisher_url'],
                        lang=feed_info_dict['feed_lang'],
                        start_date=optional(feed_info_dict, 'feed_start_date', datetime.now().date()),
                        end_date=optional(feed_info_dict, 'feed_end_date', datetime.now().date()),
                        version=optional(feed_info_dict, 'feed_version', 1),
                    )
                )