Merge branch 'master' into master

This commit is contained in:
Elliott Ashby
2025-07-26 19:01:19 +09:00
committed by GitHub
56 changed files with 3948 additions and 2009 deletions

View File

@@ -1,30 +0,0 @@
You are a senior Python developer with extensive experience in building robust, scalable applications. You excel at:
## Core Python Expertise
- Writing clean, maintainable, and efficient Python code
- Following PEP 8 style guidelines and Python best practices
- Implementing proper error handling and logging
- Using type hints and modern Python features (3.8+)
- Understanding memory management and performance optimization
## Development Practices
- Test-driven development (TDD) and writing comprehensive unit tests
- Designing modular, reusable code architectures
- Implementing design patterns appropriately
- Documentation and code commenting best practices
## Technical Skills
- CLI application development (argparse, click, typer)
## Problem-Solving Approach
- Break down complex problems into manageable components
- Consider edge cases and error scenarios
- Optimize for readability first, then performance
- Provide multiple solution approaches when applicable
- Explain trade-offs and design decisions
Always provide production-ready code with proper error handling, logging, and documentation.

View File

@@ -0,0 +1,304 @@
#!/bin/bash
#
# FZF Dynamic Preview Script Template
#
# This script handles previews for dynamic search results by parsing the JSON
# search results file and extracting info for the selected item.
# The placeholders in curly braces are dynamically filled by Python using .replace()
WIDTH=${FZF_PREVIEW_COLUMNS:-80}
IMAGE_RENDERER="{IMAGE_RENDERER}"
SEARCH_RESULTS_FILE="{SEARCH_RESULTS_FILE}"
IMAGE_CACHE_PATH="{IMAGE_CACHE_PATH}"
INFO_CACHE_PATH="{INFO_CACHE_PATH}"
PATH_SEP="{PATH_SEP}"
# Color codes injected by Python
C_TITLE="{C_TITLE}"
C_KEY="{C_KEY}"
C_VALUE="{C_VALUE}"
C_RULE="{C_RULE}"
RESET="{RESET}"
# Selected item from fzf
SELECTED_ITEM="$1"
generate_sha256() {
local input="$1"
if command -v sha256sum &>/dev/null; then
echo -n "$input" | sha256sum | awk '{print $1}'
elif command -v shasum &>/dev/null; then
echo -n "$input" | shasum -a 256 | awk '{print $1}'
elif command -v sha256 &>/dev/null; then
echo -n "$input" | sha256 | awk '{print $1}'
elif command -v openssl &>/dev/null; then
echo -n "$input" | openssl dgst -sha256 | awk '{print $2}'
else
echo -n "$input" | base64 | tr '/+' '_-' | tr -d '\n'
fi
}
fzf_preview() {
file=$1
dim=${FZF_PREVIEW_COLUMNS}x${FZF_PREVIEW_LINES}
if [ "$dim" = x ]; then
dim=$(stty size </dev/tty | awk "{print \$2 \"x\" \$1}")
fi
if ! [ "$IMAGE_RENDERER" = "icat" ] && [ -z "$KITTY_WINDOW_ID" ] && [ "$((FZF_PREVIEW_TOP + FZF_PREVIEW_LINES))" -eq "$(stty size </dev/tty | awk "{print \$1}")" ]; then
dim=${FZF_PREVIEW_COLUMNS}x$((FZF_PREVIEW_LINES - 1))
fi
if [ "$IMAGE_RENDERER" = "icat" ] && [ -z "$GHOSTTY_BIN_DIR" ]; then
if command -v kitten >/dev/null 2>&1; then
kitten icat --clear --transfer-mode=memory --unicode-placeholder --stdin=no --place="$dim@0x0" "$file" | sed "\$d" | sed "$(printf "\$s/\$/\033[m/")"
elif command -v icat >/dev/null 2>&1; then
icat --clear --transfer-mode=memory --unicode-placeholder --stdin=no --place="$dim@0x0" "$file" | sed "\$d" | sed "$(printf "\$s/\$/\033[m/")"
else
kitty icat --clear --transfer-mode=memory --unicode-placeholder --stdin=no --place="$dim@0x0" "$file" | sed "\$d" | sed "$(printf "\$s/\$/\033[m/")"
fi
elif [ -n "$GHOSTTY_BIN_DIR" ]; then
if command -v kitten >/dev/null 2>&1; then
kitten icat --clear --transfer-mode=memory --unicode-placeholder --stdin=no --place="$dim@0x0" "$file" | sed "\$d" | sed "$(printf "\$s/\$/\033[m/")"
elif command -v icat >/dev/null 2>&1; then
icat --clear --transfer-mode=memory --unicode-placeholder --stdin=no --place="$dim@0x0" "$file" | sed "\$d" | sed "$(printf "\$s/\$/\033[m/")"
else
chafa -s "$dim" "$file"
fi
elif command -v chafa >/dev/null 2>&1; then
case "$PLATFORM" in
android) chafa -s "$dim" "$file" ;;
windows) chafa -f sixel -s "$dim" "$file" ;;
*) chafa -s "$dim" "$file" ;;
esac
echo
elif command -v imgcat >/dev/null; then
imgcat -W "${dim%%x*}" -H "${dim##*x}" "$file"
else
echo please install a terminal image viewer
echo either icat for kitty terminal and wezterm or imgcat or chafa
fi
}
print_kv() {
local key="$1"
local value="$2"
local key_len=${#key}
local value_len=${#value}
local multiplier="${3:-1}"
local padding_len=$((WIDTH - key_len - 2 - value_len * multiplier))
if [ "$padding_len" -lt 1 ]; then
padding_len=1
value=$(echo $value| fold -s -w "$((WIDTH - key_len - 3))")
printf "${C_KEY}%s:${RESET}%*s%s\\n" "$key" "$padding_len" "" " $value"
else
printf "${C_KEY}%s:${RESET}%*s%s\\n" "$key" "$padding_len" "" " $value"
fi
}
draw_rule() {
ll=2
while [ $ll -le $FZF_PREVIEW_COLUMNS ];do
echo -n -e "${C_RULE}${RESET}"
((ll++))
done
echo
}
clean_html() {
echo "$1" | sed 's/<[^>]*>//g' | sed 's/&lt;/</g' | sed 's/&gt;/>/g' | sed 's/&amp;/\&/g' | sed 's/&quot;/"/g' | sed "s/&#39;/'/g"
}
format_date() {
local date_obj="$1"
if [ "$date_obj" = "null" ] || [ -z "$date_obj" ]; then
echo "N/A"
return
fi
# Extract year, month, day from the date object
if command -v jq >/dev/null 2>&1; then
year=$(echo "$date_obj" | jq -r '.year // "N/A"' 2>/dev/null || echo "N/A")
month=$(echo "$date_obj" | jq -r '.month // ""' 2>/dev/null || echo "")
day=$(echo "$date_obj" | jq -r '.day // ""' 2>/dev/null || echo "")
else
year=$(echo "$date_obj" | python3 -c "import json, sys; data=json.load(sys.stdin); print(data.get('year', 'N/A'))" 2>/dev/null || echo "N/A")
month=$(echo "$date_obj" | python3 -c "import json, sys; data=json.load(sys.stdin); print(data.get('month', ''))" 2>/dev/null || echo "")
day=$(echo "$date_obj" | python3 -c "import json, sys; data=json.load(sys.stdin); print(data.get('day', ''))" 2>/dev/null || echo "")
fi
if [ "$year" = "N/A" ] || [ "$year" = "null" ]; then
echo "N/A"
elif [ -n "$month" ] && [ "$month" != "null" ] && [ -n "$day" ] && [ "$day" != "null" ]; then
echo "$day/$month/$year"
elif [ -n "$month" ] && [ "$month" != "null" ]; then
echo "$month/$year"
else
echo "$year"
fi
}
# If no selection or search results file doesn't exist, show placeholder
if [ -z "$SELECTED_ITEM" ] || [ ! -f "$SEARCH_RESULTS_FILE" ]; then
echo "${C_TITLE}Dynamic Search Preview${RESET}"
draw_rule
echo "Type to search for anime..."
echo "Results will appear here as you type."
exit 0
fi
# Parse the search results JSON and find the matching item
if command -v jq >/dev/null 2>&1; then
# Use jq for faster and more reliable JSON parsing
MEDIA_DATA=$(cat "$SEARCH_RESULTS_FILE" | jq --arg selected "$SELECTED_ITEM" '
.data.Page.media[]? |
select(
((.title.english // .title.romaji // .title.native // "Unknown") +
" (" + (.startDate.year // "Unknown" | tostring) + ") " +
"[" + (.status // "Unknown") + "] - " +
((.genres[:3] // []) | join(", ") | if . == "" then "Unknown" else . end)
) == $selected
)
' 2>/dev/null)
else
# Fallback to Python for JSON parsing
MEDIA_DATA=$(cat "$SEARCH_RESULTS_FILE" | python3 -c "
import json
import sys
try:
data = json.load(sys.stdin)
selected_item = '''$SELECTED_ITEM'''
if 'data' not in data or 'Page' not in data['data'] or 'media' not in data['data']['Page']:
sys.exit(1)
media_list = data['data']['Page']['media']
for media in media_list:
title = media.get('title', {})
english_title = title.get('english') or title.get('romaji') or title.get('native', 'Unknown')
year = media.get('startDate', {}).get('year', 'Unknown') if media.get('startDate') else 'Unknown'
status = media.get('status', 'Unknown')
genres = ', '.join(media.get('genres', [])[:3]) or 'Unknown'
display_format = f'{english_title} ({year}) [{status}] - {genres}'
if selected_item.strip() == display_format.strip():
json.dump(media, sys.stdout, indent=2)
sys.exit(0)
sys.exit(1)
except Exception as e:
print(f'Error: {e}', file=sys.stderr)
sys.exit(1)
" 2>/dev/null)
fi
# If we couldn't find the media data, show error
if [ $? -ne 0 ] || [ -z "$MEDIA_DATA" ]; then
echo "${C_TITLE}Preview Error${RESET}"
draw_rule
echo "Could not load preview data for:"
echo "$SELECTED_ITEM"
exit 0
fi
# Extract information from the media data
if command -v jq >/dev/null 2>&1; then
# Use jq for faster extraction
TITLE=$(echo "$MEDIA_DATA" | jq -r '.title.english // .title.romaji // .title.native // "Unknown"' 2>/dev/null || echo "Unknown")
STATUS=$(echo "$MEDIA_DATA" | jq -r '.status // "Unknown"' 2>/dev/null || echo "Unknown")
FORMAT=$(echo "$MEDIA_DATA" | jq -r '.format // "Unknown"' 2>/dev/null || echo "Unknown")
EPISODES=$(echo "$MEDIA_DATA" | jq -r '.episodes // "Unknown"' 2>/dev/null || echo "Unknown")
DURATION=$(echo "$MEDIA_DATA" | jq -r 'if .duration then "\(.duration) min" else "Unknown" end' 2>/dev/null || echo "Unknown")
SCORE=$(echo "$MEDIA_DATA" | jq -r 'if .averageScore then "\(.averageScore)/100" else "N/A" end' 2>/dev/null || echo "N/A")
FAVOURITES=$(echo "$MEDIA_DATA" | jq -r '.favourites // 0' 2>/dev/null | sed ':a;s/\B[0-9]\{3\}\>/,&/;ta' || echo "0")
POPULARITY=$(echo "$MEDIA_DATA" | jq -r '.popularity // 0' 2>/dev/null | sed ':a;s/\B[0-9]\{3\}\>/,&/;ta' || echo "0")
GENRES=$(echo "$MEDIA_DATA" | jq -r '(.genres[:5] // []) | join(", ") | if . == "" then "Unknown" else . end' 2>/dev/null || echo "Unknown")
DESCRIPTION=$(echo "$MEDIA_DATA" | jq -r '.description // "No description available."' 2>/dev/null || echo "No description available.")
# Get start and end dates as JSON objects
START_DATE_OBJ=$(echo "$MEDIA_DATA" | jq -c '.startDate' 2>/dev/null || echo "null")
END_DATE_OBJ=$(echo "$MEDIA_DATA" | jq -c '.endDate' 2>/dev/null || echo "null")
# Get cover image URL
COVER_IMAGE=$(echo "$MEDIA_DATA" | jq -r '.coverImage.large // ""' 2>/dev/null || echo "")
else
# Fallback to Python for extraction
TITLE=$(echo "$MEDIA_DATA" | python3 -c "import json, sys; data=json.load(sys.stdin); title=data.get('title',{}); print(title.get('english') or title.get('romaji') or title.get('native', 'Unknown'))" 2>/dev/null || echo "Unknown")
STATUS=$(echo "$MEDIA_DATA" | python3 -c "import json, sys; data=json.load(sys.stdin); print(data.get('status', 'Unknown'))" 2>/dev/null || echo "Unknown")
FORMAT=$(echo "$MEDIA_DATA" | python3 -c "import json, sys; data=json.load(sys.stdin); print(data.get('format', 'Unknown'))" 2>/dev/null || echo "Unknown")
EPISODES=$(echo "$MEDIA_DATA" | python3 -c "import json, sys; data=json.load(sys.stdin); print(data.get('episodes', 'Unknown'))" 2>/dev/null || echo "Unknown")
DURATION=$(echo "$MEDIA_DATA" | python3 -c "import json, sys; data=json.load(sys.stdin); duration=data.get('duration'); print(f'{duration} min' if duration else 'Unknown')" 2>/dev/null || echo "Unknown")
SCORE=$(echo "$MEDIA_DATA" | python3 -c "import json, sys; data=json.load(sys.stdin); score=data.get('averageScore'); print(f'{score}/100' if score else 'N/A')" 2>/dev/null || echo "N/A")
FAVOURITES=$(echo "$MEDIA_DATA" | python3 -c "import json, sys; data=json.load(sys.stdin); print(f\"{data.get('favourites', 0):,}\")" 2>/dev/null || echo "0")
POPULARITY=$(echo "$MEDIA_DATA" | python3 -c "import json, sys; data=json.load(sys.stdin); print(f\"{data.get('popularity', 0):,}\")" 2>/dev/null || echo "0")
GENRES=$(echo "$MEDIA_DATA" | python3 -c "import json, sys; data=json.load(sys.stdin); print(', '.join(data.get('genres', [])[:5]))" 2>/dev/null || echo "Unknown")
DESCRIPTION=$(echo "$MEDIA_DATA" | python3 -c "import json, sys; data=json.load(sys.stdin); print(data.get('description', 'No description available.'))" 2>/dev/null || echo "No description available.")
# Get start and end dates
START_DATE_OBJ=$(echo "$MEDIA_DATA" | python3 -c "import json, sys; data=json.load(sys.stdin); json.dump(data.get('startDate'), sys.stdout)" 2>/dev/null || echo "null")
END_DATE_OBJ=$(echo "$MEDIA_DATA" | python3 -c "import json, sys; data=json.load(sys.stdin); json.dump(data.get('endDate'), sys.stdout)" 2>/dev/null || echo "null")
# Get cover image URL
COVER_IMAGE=$(echo "$MEDIA_DATA" | python3 -c "import json, sys; data=json.load(sys.stdin); cover=data.get('coverImage',{}); print(cover.get('large', ''))" 2>/dev/null || echo "")
fi
# Format the dates
START_DATE=$(format_date "$START_DATE_OBJ")
END_DATE=$(format_date "$END_DATE_OBJ")
# Generate cache hash for this item
CACHE_HASH=$(generate_sha256 "dynamic_search_$TITLE")
# Try to show image if available
if [ "{PREVIEW_MODE}" = "full" ] || [ "{PREVIEW_MODE}" = "image" ]; then
image_file="${IMAGE_CACHE_PATH}${PATH_SEP}${CACHE_HASH}.png"
# If image not cached and we have a URL, try to download it quickly
if [ ! -f "$image_file" ] && [ -n "$COVER_IMAGE" ]; then
if command -v curl >/dev/null 2>&1; then
# Quick download with timeout
curl -s -m 3 -L "$COVER_IMAGE" -o "$image_file" 2>/dev/null || rm -f "$image_file" 2>/dev/null
fi
fi
if [ -f "$image_file" ]; then
fzf_preview "$image_file"
else
echo "🖼️ Loading image..."
fi
echo
fi
# Display text info if configured
if [ "{PREVIEW_MODE}" = "full" ] || [ "{PREVIEW_MODE}" = "text" ]; then
draw_rule
print_kv "Title" "$TITLE"
draw_rule
print_kv "Score" "$SCORE"
print_kv "Favourites" "$FAVOURITES"
print_kv "Popularity" "$POPULARITY"
print_kv "Status" "$STATUS"
draw_rule
print_kv "Episodes" "$EPISODES"
print_kv "Duration" "$DURATION"
print_kv "Format" "$FORMAT"
draw_rule
print_kv "Genres" "$GENRES"
print_kv "Start Date" "$START_DATE"
print_kv "End Date" "$END_DATE"
draw_rule
# Clean and display description
CLEAN_DESCRIPTION=$(clean_html "$DESCRIPTION")
echo "$CLEAN_DESCRIPTION" | fold -s -w "$WIDTH"
fi

View File

@@ -1,77 +1,122 @@
fetch_anime_for_fzf() {
local search_term="$1"
if [ -z "$search_term" ]; then exit 0; fi
#!/bin/bash
#
# FZF Dynamic Search Script Template
#
# This script is a template for dynamic search functionality in fzf.
# The placeholders in curly braces, like {QUERY} are dynamically filled by Python using .replace()
local query='
query ($search: String) {
Page(page: 1, perPage: 25) {
media(search: $search, type: ANIME, sort: [SEARCH_MATCH]) {
id
title { romaji english }
meanScore
format
status
}
}
}
'
# Configuration variables (injected by Python)
GRAPHQL_ENDPOINT="{GRAPHQL_ENDPOINT}"
CACHE_DIR="{CACHE_DIR}"
SEARCH_RESULTS_FILE="{SEARCH_RESULTS_FILE}"
AUTH_HEADER="{AUTH_HEADER}"
local json_payload
json_payload=$(jq -n --arg query "$query" --arg search "$search_term" \
'{query: $query, variables: {search: $search}}')
# Get the current query from fzf
QUERY="$1"
curl --silent \
--header "Content-Type: application/json" \
--header "Accept: application/json" \
--request POST \
--data "$json_payload" \
https://graphql.anilist.co |
jq -r '.data.Page.media[]? | select(.title.romaji) |
"\(.title.english // .title.romaji) | Score: \(.meanScore // "N/A") | ID: \(.id)"'
# If query is empty, exit with empty results
if [ -z "$QUERY" ]; then
echo ""
exit 0
fi
# Create GraphQL variables
VARIABLES=$(cat <<EOF
{
"query": "$QUERY",
"type": "ANIME",
"per_page": 50,
"genre_not_in": ["Hentai"]
}
fetch_anime_details() {
local anime_id
anime_id=$(echo "$1" | sed -n 's/.*ID: \([0-9]*\).*/\1/p')
if [ -z "$anime_id" ]; then
echo "Select an item to see details..."
return
fi
EOF
)
local query='
query ($id: Int) {
Media(id: $id, type: ANIME) {
title { romaji english }
description(asHtml: false)
genres
meanScore
episodes
status
format
season
seasonYear
studios(isMain: true) { nodes { name } }
}
}
'
local json_payload
json_payload=$(jq -n --arg query "$query" --argjson id "$anime_id" \
'{query: $query, variables: {id: $id}}')
# The GraphQL query is injected here as a properly escaped string
GRAPHQL_QUERY='{GRAPHQL_QUERY}'
# Fetch and format details for the preview window
curl --silent \
--header "Content-Type: application/json" \
--header "Accept: application/json" \
--request POST \
--data "$json_payload" \
https://graphql.anilist.co |
jq -r '
.data.Media |
"Title: \(.title.english // .title.romaji)\n" +
"Score: \(.meanScore // "N/A") | Episodes: \(.episodes // "N/A")\n" +
"Status: \(.status // "N/A") | Format: \(.format // "N/A")\n" +
"Season: \(.season // "N/A") \(.seasonYear // "")\n" +
"Genres: \(.genres | join(", "))\n" +
"Studio: \(.studios.nodes[0].name // "N/A")\n\n" +
"\(.description | gsub("<br><br>"; "\n\n") | gsub("<[^>]*>"; "") | gsub("&quot;"; "\""))"
'
# Create the GraphQL request payload
PAYLOAD=$(cat <<EOF
{
"query": $GRAPHQL_QUERY,
"variables": $VARIABLES
}
EOF
)
# Make the GraphQL request and save raw results
if [ -n "$AUTH_HEADER" ]; then
RESPONSE=$(curl -s -X POST \
-H "Content-Type: application/json" \
-H "Authorization: $AUTH_HEADER" \
-d "$PAYLOAD" \
"$GRAPHQL_ENDPOINT")
else
RESPONSE=$(curl -s -X POST \
-H "Content-Type: application/json" \
-d "$PAYLOAD" \
"$GRAPHQL_ENDPOINT")
fi
# Check if the request was successful
if [ $? -ne 0 ] || [ -z "$RESPONSE" ]; then
echo "❌ Search failed"
exit 1
fi
# Save the raw response for later processing
echo "$RESPONSE" > "$SEARCH_RESULTS_FILE"
# Parse and display results
if command -v jq >/dev/null 2>&1; then
# Use jq for faster and more reliable JSON parsing
echo "$RESPONSE" | jq -r '
if .errors then
"❌ Search error: " + (.errors | tostring)
elif (.data.Page.media // []) | length == 0 then
"❌ No results found"
else
.data.Page.media[] |
((.title.english // .title.romaji // .title.native // "Unknown") +
" (" + (.startDate.year // "Unknown" | tostring) + ") " +
"[" + (.status // "Unknown") + "] - " +
((.genres[:3] // []) | join(", ") | if . == "" then "Unknown" else . end))
end
' 2>/dev/null || echo "❌ Parse error"
else
# Fallback to Python for JSON parsing
echo "$RESPONSE" | python3 -c "
import json
import sys
try:
data = json.load(sys.stdin)
if 'errors' in data:
print('❌ Search error: ' + str(data['errors']))
sys.exit(1)
if 'data' not in data or 'Page' not in data['data'] or 'media' not in data['data']['Page']:
print('❌ No results found')
sys.exit(0)
media_list = data['data']['Page']['media']
if not media_list:
print('❌ No results found')
sys.exit(0)
for media in media_list:
title = media.get('title', {})
english_title = title.get('english') or title.get('romaji') or title.get('native', 'Unknown')
year = media.get('startDate', {}).get('year', 'Unknown') if media.get('startDate') else 'Unknown'
status = media.get('status', 'Unknown')
genres = ', '.join(media.get('genres', [])[:3]) or 'Unknown'
# Format: Title (Year) [Status] - Genres
print(f'{english_title} ({year}) [{status}] - {genres}')
except Exception as e:
print(f'❌ Parse error: {str(e)}')
sys.exit(1)
"
fi

View File

@@ -10,6 +10,7 @@ commands = {
"download": "download.download",
# "downloads": "downloads.downloads",
"auth": "auth.auth",
"stats": "stats.stats",
}

View File

@@ -68,10 +68,10 @@ if TYPE_CHECKING:
epilog=examples.download,
)
@click.option(
"--title",
"-t",
"--title",
"-t",
shell_complete=anime_titles_shell_complete,
help="Title of the anime to search for"
help="Title of the anime to search for",
)
@click.option(
"--episode-range",
@@ -239,7 +239,9 @@ def download(config: AppConfig, **options: "Unpack[DownloadOptions]"):
# Initialize services
feedback.info("Initializing services...")
api_client, provider, selector, media_registry, download_service = _initialize_services(config)
api_client, provider, selector, media_registry, download_service = (
_initialize_services(config)
)
feedback.info(f"Using provider: {provider.__class__.__name__}")
feedback.info(f"Using media API: {config.general.media_api}")
feedback.info(f"Translation type: {config.stream.translation_type}")
@@ -256,16 +258,22 @@ def download(config: AppConfig, **options: "Unpack[DownloadOptions]"):
# Process each selected anime
for selected_anime in selected_anime_list:
feedback.info(f"Processing: {selected_anime.title.english or selected_anime.title.romaji}")
feedback.info(
f"Processing: {selected_anime.title.english or selected_anime.title.romaji}"
)
feedback.info(f"AniList ID: {selected_anime.id}")
# Get available episodes from provider
episodes_result = _get_available_episodes(provider, selected_anime, config, feedback)
episodes_result = _get_available_episodes(
provider, selected_anime, config, feedback
)
if not episodes_result:
feedback.warning(f"No episodes found for {selected_anime.title.english or selected_anime.title.romaji}")
feedback.warning(
f"No episodes found for {selected_anime.title.english or selected_anime.title.romaji}"
)
_suggest_alternatives(selected_anime, provider, config, feedback)
continue
# Unpack the result
if len(episodes_result) == 2:
available_episodes, provider_anime_data = episodes_result
@@ -282,32 +290,51 @@ def download(config: AppConfig, **options: "Unpack[DownloadOptions]"):
feedback.warning("No episodes selected for download")
continue
feedback.info(f"About to download {len(episodes_to_download)} episodes: {', '.join(episodes_to_download)}")
feedback.info(
f"About to download {len(episodes_to_download)} episodes: {', '.join(episodes_to_download)}"
)
# Test stream availability before attempting download (using provider anime data)
if episodes_to_download and provider_anime_data:
test_episode = episodes_to_download[0]
feedback.info(f"Testing stream availability for episode {test_episode}...")
success = _test_episode_stream_availability(provider, provider_anime_data, test_episode, config, feedback)
feedback.info(
f"Testing stream availability for episode {test_episode}..."
)
success = _test_episode_stream_availability(
provider, provider_anime_data, test_episode, config, feedback
)
if not success:
feedback.warning(f"Stream test failed for episode {test_episode}.")
feedback.info("Possible solutions:")
feedback.info("1. Try a different provider (check your config)")
feedback.info("2. Check if the episode number is correct")
feedback.info("3. Try a different translation type (sub/dub)")
feedback.info("4. The anime might not be available on this provider")
feedback.info(
"4. The anime might not be available on this provider"
)
# Ask user if they want to continue anyway
continue_anyway = input("\nContinue with download anyway? (y/N): ").strip().lower()
if continue_anyway not in ['y', 'yes']:
continue_anyway = (
input("\nContinue with download anyway? (y/N): ")
.strip()
.lower()
)
if continue_anyway not in ["y", "yes"]:
feedback.info("Download cancelled by user")
continue
# Download episodes (using provider anime data if available, otherwise AniList data)
anime_for_download = provider_anime_data if provider_anime_data else selected_anime
anime_for_download = (
provider_anime_data if provider_anime_data else selected_anime
)
_download_episodes(
download_service, anime_for_download, episodes_to_download,
quality, force_redownload, max_concurrent, feedback
download_service,
anime_for_download,
episodes_to_download,
quality,
force_redownload,
max_concurrent,
feedback,
)
# Show final statistics
@@ -333,18 +360,36 @@ def _validate_options(options: "DownloadOptions") -> None:
end_date_lesser = options.get("end_date_lesser")
# Score validation
if score_greater is not None and score_lesser is not None and score_greater > score_lesser:
if (
score_greater is not None
and score_lesser is not None
and score_greater > score_lesser
):
raise FastAnimeError("Minimum score cannot be higher than maximum score")
# Popularity validation
if popularity_greater is not None and popularity_lesser is not None and popularity_greater > popularity_lesser:
raise FastAnimeError("Minimum popularity cannot be higher than maximum popularity")
if (
popularity_greater is not None
and popularity_lesser is not None
and popularity_greater > popularity_lesser
):
raise FastAnimeError(
"Minimum popularity cannot be higher than maximum popularity"
)
# Date validation
if start_date_greater is not None and start_date_lesser is not None and start_date_greater > start_date_lesser:
if (
start_date_greater is not None
and start_date_lesser is not None
and start_date_greater > start_date_lesser
):
raise FastAnimeError("Minimum start date cannot be after maximum start date")
if end_date_greater is not None and end_date_lesser is not None and end_date_greater > end_date_lesser:
if (
end_date_greater is not None
and end_date_lesser is not None
and end_date_greater > end_date_lesser
):
raise FastAnimeError("Minimum end date cannot be after maximum end date")
@@ -353,27 +398,47 @@ def _initialize_services(config: AppConfig) -> tuple:
api_client = create_api_client(config.general.media_api, config)
provider = create_provider(config.general.provider)
selector = create_selector(config)
media_registry = MediaRegistryService(config.general.media_api, config.media_registry)
media_registry = MediaRegistryService(
config.general.media_api, config.media_registry
)
download_service = DownloadService(config, media_registry, provider)
return api_client, provider, selector, media_registry, download_service
def _build_search_params(options: "DownloadOptions", config: AppConfig) -> MediaSearchParams:
def _build_search_params(
options: "DownloadOptions", config: AppConfig
) -> MediaSearchParams:
"""Build MediaSearchParams from command options."""
return MediaSearchParams(
query=options.get("title"),
page=options.get("page", 1),
per_page=options.get("per_page") or config.anilist.per_page or 50,
sort=MediaSort(options.get("sort")) if options.get("sort") else None,
status_in=[MediaStatus(s) for s in options.get("status", ())] if options.get("status") else None,
status_not_in=[MediaStatus(s) for s in options.get("status_not", ())] if options.get("status_not") else None,
genre_in=[MediaGenre(g) for g in options.get("genres", ())] if options.get("genres") else None,
genre_not_in=[MediaGenre(g) for g in options.get("genres_not", ())] if options.get("genres_not") else None,
tag_in=[MediaTag(t) for t in options.get("tags", ())] if options.get("tags") else None,
tag_not_in=[MediaTag(t) for t in options.get("tags_not", ())] if options.get("tags_not") else None,
format_in=[MediaFormat(f) for f in options.get("media_format", ())] if options.get("media_format") else None,
type=MediaType(options.get("media_type")) if options.get("media_type") else None,
status_in=[MediaStatus(s) for s in options.get("status", ())]
if options.get("status")
else None,
status_not_in=[MediaStatus(s) for s in options.get("status_not", ())]
if options.get("status_not")
else None,
genre_in=[MediaGenre(g) for g in options.get("genres", ())]
if options.get("genres")
else None,
genre_not_in=[MediaGenre(g) for g in options.get("genres_not", ())]
if options.get("genres_not")
else None,
tag_in=[MediaTag(t) for t in options.get("tags", ())]
if options.get("tags")
else None,
tag_not_in=[MediaTag(t) for t in options.get("tags_not", ())]
if options.get("tags_not")
else None,
format_in=[MediaFormat(f) for f in options.get("media_format", ())]
if options.get("media_format")
else None,
type=MediaType(options.get("media_type"))
if options.get("media_type")
else None,
season=MediaSeason(options.get("season")) if options.get("season") else None,
seasonYear=int(year) if (year := options.get("year")) else None,
popularity_greater=options.get("popularity_greater"),
@@ -393,20 +458,24 @@ def _search_anime(api_client, search_params, feedback):
from rich.progress import Progress, SpinnerColumn, TextColumn
# Check if we have any search criteria at all
has_criteria = any([
search_params.query,
search_params.genre_in,
search_params.tag_in,
search_params.status_in,
search_params.season,
search_params.seasonYear,
search_params.format_in,
search_params.popularity_greater,
search_params.averageScore_greater,
])
has_criteria = any(
[
search_params.query,
search_params.genre_in,
search_params.tag_in,
search_params.status_in,
search_params.season,
search_params.seasonYear,
search_params.format_in,
search_params.popularity_greater,
search_params.averageScore_greater,
]
)
if not has_criteria:
raise FastAnimeError("Please provide at least one search criterion (title, genre, tag, status, etc.)")
raise FastAnimeError(
"Please provide at least one search criterion (title, genre, tag, status, etc.)"
)
with Progress(
SpinnerColumn(),
@@ -426,7 +495,9 @@ def _select_anime(search_result, selector, feedback):
"""Let user select anime from search results."""
if len(search_result.media) == 1:
selected_anime = search_result.media[0]
feedback.info(f"Auto-selected: {selected_anime.title.english or selected_anime.title.romaji}")
feedback.info(
f"Auto-selected: {selected_anime.title.english or selected_anime.title.romaji}"
)
return [selected_anime]
# Create choice strings with additional info
@@ -467,41 +538,53 @@ def _get_available_episodes(provider, anime, config, feedback):
try:
# Search for anime in provider first
media_title = anime.title.english or anime.title.romaji
feedback.info(f"Searching provider '{provider.__class__.__name__}' for: '{media_title}'")
feedback.info(
f"Searching provider '{provider.__class__.__name__}' for: '{media_title}'"
)
feedback.info(f"Using translation type: '{config.stream.translation_type}'")
provider_search_results = provider.search(
SearchParams(query=media_title, translation_type=config.stream.translation_type)
SearchParams(
query=media_title, translation_type=config.stream.translation_type
)
)
if not provider_search_results or not provider_search_results.results:
feedback.warning(f"Could not find '{media_title}' on provider '{provider.__class__.__name__}'")
feedback.warning(
f"Could not find '{media_title}' on provider '{provider.__class__.__name__}'"
)
return []
feedback.info(f"Found {len(provider_search_results.results)} results on provider")
feedback.info(
f"Found {len(provider_search_results.results)} results on provider"
)
# Show the first few results for debugging
for i, result in enumerate(provider_search_results.results[:3]):
feedback.info(f"Result {i+1}: ID={result.id}, Title='{getattr(result, 'title', 'Unknown')}'")
feedback.info(
f"Result {i + 1}: ID={result.id}, Title='{getattr(result, 'title', 'Unknown')}'"
)
# Get the first result (could be enhanced with fuzzy matching)
first_result = provider_search_results.results[0]
feedback.info(f"Using first result: ID={first_result.id}")
# Now get the full anime data using the PROVIDER'S ID, not AniList ID
provider_anime_data = provider.get(
AnimeParams(id=first_result.id, query=media_title)
)
if not provider_anime_data:
feedback.warning(f"Failed to get anime details from provider")
feedback.warning("Failed to get anime details from provider")
return []
# Check all available translation types
translation_types = ['sub', 'dub']
translation_types = ["sub", "dub"]
for trans_type in translation_types:
episodes = getattr(provider_anime_data.episodes, trans_type, [])
feedback.info(f"Translation '{trans_type}': {len(episodes)} episodes available")
feedback.info(
f"Translation '{trans_type}': {len(episodes)} episodes available"
)
available_episodes = getattr(
provider_anime_data.episodes, config.stream.translation_type, []
@@ -512,33 +595,46 @@ def _get_available_episodes(provider, anime, config, feedback):
# Suggest alternative translation type if available
for trans_type in translation_types:
if trans_type != config.stream.translation_type:
other_episodes = getattr(provider_anime_data.episodes, trans_type, [])
other_episodes = getattr(
provider_anime_data.episodes, trans_type, []
)
if other_episodes:
feedback.info(f"Suggestion: Try using translation type '{trans_type}' (has {len(other_episodes)} episodes)")
feedback.info(
f"Suggestion: Try using translation type '{trans_type}' (has {len(other_episodes)} episodes)"
)
return []
feedback.info(f"Found {len(available_episodes)} episodes available for download")
feedback.info(
f"Found {len(available_episodes)} episodes available for download"
)
# Return both episodes and the provider anime data for later use
return available_episodes, provider_anime_data
except Exception as e:
feedback.error(f"Error getting episodes from provider: {e}")
import traceback
feedback.error("Full traceback", traceback.format_exc())
return []
def _determine_episodes_to_download(episode_range, available_episodes, selector, feedback):
def _determine_episodes_to_download(
episode_range, available_episodes, selector, feedback
):
"""Determine which episodes to download based on range or user selection."""
if not available_episodes:
feedback.warning("No episodes available to download")
return []
if episode_range:
try:
episodes_to_download = list(parse_episode_range(episode_range, available_episodes))
feedback.info(f"Episodes from range '{episode_range}': {', '.join(episodes_to_download)}")
episodes_to_download = list(
parse_episode_range(episode_range, available_episodes)
)
feedback.info(
f"Episodes from range '{episode_range}': {', '.join(episodes_to_download)}"
)
return episodes_to_download
except (ValueError, IndexError) as e:
feedback.error(f"Invalid episode range '{episode_range}': {e}")
@@ -550,10 +646,10 @@ def _determine_episodes_to_download(episode_range, available_episodes, selector,
choices=available_episodes,
header="Use TAB to select multiple episodes, ENTER to confirm",
)
if selected_episodes:
feedback.info(f"Selected episodes: {', '.join(selected_episodes)}")
return selected_episodes
@@ -563,13 +659,17 @@ def _suggest_alternatives(anime, provider, config, feedback):
feedback.info(f"1. Current provider: {provider.__class__.__name__}")
feedback.info(f"2. AniList ID being used: {anime.id}")
feedback.info(f"3. Translation type: {config.stream.translation_type}")
# Special message for AllAnime provider
if provider.__class__.__name__ == "AllAnimeProvider":
feedback.info("4. AllAnime ID mismatch: AllAnime uses different IDs than AniList")
feedback.info(
"4. AllAnime ID mismatch: AllAnime uses different IDs than AniList"
)
feedback.info(" The provider searches by title, but episodes use AniList ID")
feedback.info(" This can cause episodes to not be found even if the anime exists")
feedback.info(
" This can cause episodes to not be found even if the anime exists"
)
# Check if provider has different ID mapping
anime_titles = []
if anime.title.english:
@@ -578,7 +678,7 @@ def _suggest_alternatives(anime, provider, config, feedback):
anime_titles.append(anime.title.romaji)
if anime.title.native:
anime_titles.append(anime.title.native)
feedback.info(f"5. Available titles: {', '.join(anime_titles)}")
feedback.info("6. Possible solutions:")
feedback.info(" - Try a different provider (GogoAnime, 9anime, etc.)")
@@ -588,7 +688,15 @@ def _suggest_alternatives(anime, provider, config, feedback):
feedback.info(" - Check if anime is available in your region")
def _download_episodes(download_service, anime, episodes, quality, force_redownload, max_concurrent, feedback):
def _download_episodes(
download_service,
anime,
episodes,
quality,
force_redownload,
max_concurrent,
feedback,
):
"""Download the specified episodes."""
from concurrent.futures import ThreadPoolExecutor, as_completed
from rich.console import Console
@@ -607,18 +715,19 @@ def _download_episodes(download_service, anime, episodes, quality, force_redownl
anime_title = anime.title.english or anime.title.romaji
console.print(f"\n[bold green]Starting downloads for: {anime_title}[/bold green]")
# Set up logging capture to get download errors
log_messages = []
class ListHandler(logging.Handler):
def emit(self, record):
log_messages.append(self.format(record))
handler = ListHandler()
handler.setLevel(logging.ERROR)
logger = logging.getLogger('fastanime')
logger = logging.getLogger("fastanime")
logger.addHandler(handler)
try:
with Progress(
SpinnerColumn(),
@@ -628,18 +737,19 @@ def _download_episodes(download_service, anime, episodes, quality, force_redownl
TaskProgressColumn(),
TimeElapsedColumn(),
) as progress:
task = progress.add_task("Downloading episodes...", total=len(episodes))
if max_concurrent == 1:
# Sequential downloads
results = {}
for episode in episodes:
progress.update(task, description=f"Downloading episode {episode}...")
progress.update(
task, description=f"Downloading episode {episode}..."
)
# Clear previous log messages for this episode
log_messages.clear()
try:
success = download_service.download_episode(
media_item=anime,
@@ -648,19 +758,26 @@ def _download_episodes(download_service, anime, episodes, quality, force_redownl
force_redownload=force_redownload,
)
results[episode] = success
if not success:
# Try to get more detailed error from registry
error_msg = _get_episode_error_details(download_service, anime, episode)
error_msg = _get_episode_error_details(
download_service, anime, episode
)
if error_msg:
feedback.error(f"Episode {episode}", error_msg)
elif log_messages:
# Show any log messages that were captured
for msg in log_messages[-3:]: # Show last 3 error messages
for msg in log_messages[
-3:
]: # Show last 3 error messages
feedback.error(f"Episode {episode}", msg)
else:
feedback.error(f"Episode {episode}", "Download failed - check logs for details")
feedback.error(
f"Episode {episode}",
"Download failed - check logs for details",
)
except Exception as e:
results[episode] = False
feedback.error(f"Episode {episode} failed", str(e))
@@ -681,7 +798,7 @@ def _download_episodes(download_service, anime, episodes, quality, force_redownl
): episode
for episode in episodes
}
# Process completed downloads
for future in as_completed(future_to_episode):
episode = future_to_episode[future]
@@ -690,15 +807,22 @@ def _download_episodes(download_service, anime, episodes, quality, force_redownl
results[episode] = success
if not success:
# Try to get more detailed error from registry
error_msg = _get_episode_error_details(download_service, anime, episode)
error_msg = _get_episode_error_details(
download_service, anime, episode
)
if error_msg:
feedback.error(f"Episode {episode}", error_msg)
else:
feedback.error(f"Episode {episode}", "Download failed - check logs for details")
feedback.error(
f"Episode {episode}",
"Download failed - check logs for details",
)
except Exception as e:
results[episode] = False
feedback.error(f"Download failed for episode {episode}", str(e))
feedback.error(
f"Download failed for episode {episode}", str(e)
)
progress.advance(task)
finally:
# Remove the log handler
@@ -715,13 +839,13 @@ def _get_episode_error_details(download_service, anime, episode_number):
media_record = download_service.media_registry.get_record(anime.id)
if not media_record:
return None
# Find the episode in the record
for episode_record in media_record.episodes:
if episode_record.episode_number == episode_number:
if episode_record.error_message:
error_msg = episode_record.error_message
# Provide more helpful error messages for common issues
if "Failed to get server for episode" in error_msg:
return f"Episode {episode_number} not available on current provider. Try a different provider or check episode number."
@@ -732,20 +856,24 @@ def _get_episode_error_details(download_service, anime, episode_number):
elif episode_record.download_status:
return f"Download status: {episode_record.download_status.value}"
break
return None
except Exception:
return None
def _test_episode_stream_availability(provider, anime, episode_number, config, feedback):
def _test_episode_stream_availability(
provider, anime, episode_number, config, feedback
):
"""Test if streams are available for a specific episode."""
try:
from .....libs.provider.anime.params import EpisodeStreamsParams
media_title = anime.title.english or anime.title.romaji
feedback.info(f"Testing stream availability for '{media_title}' episode {episode_number}")
feedback.info(
f"Testing stream availability for '{media_title}' episode {episode_number}"
)
# Test episode streams
streams = provider.episode_streams(
EpisodeStreamsParams(
@@ -755,29 +883,39 @@ def _test_episode_stream_availability(provider, anime, episode_number, config, f
translation_type=config.stream.translation_type,
)
)
if not streams:
feedback.warning(f"No streams found for episode {episode_number}")
return False
# Convert to list to check actual availability
stream_list = list(streams)
if not stream_list:
feedback.warning(f"No stream servers available for episode {episode_number}")
feedback.warning(
f"No stream servers available for episode {episode_number}"
)
return False
feedback.info(f"Found {len(stream_list)} stream server(s) for episode {episode_number}")
feedback.info(
f"Found {len(stream_list)} stream server(s) for episode {episode_number}"
)
# Show details about the first server for debugging
first_server = stream_list[0]
feedback.info(f"First server: name='{first_server.name}', type='{type(first_server).__name__}'")
feedback.info(
f"First server: name='{first_server.name}', type='{type(first_server).__name__}'"
)
return True
except TypeError as e:
if "'NoneType' object is not subscriptable" in str(e):
feedback.warning(f"Episode {episode_number} not available on provider (API returned null)")
feedback.info("This usually means the episode doesn't exist on this provider or isn't accessible")
feedback.warning(
f"Episode {episode_number} not available on provider (API returned null)"
)
feedback.info(
"This usually means the episode doesn't exist on this provider or isn't accessible"
)
return False
else:
feedback.error(f"Type error testing stream availability: {e}")
@@ -785,6 +923,7 @@ def _test_episode_stream_availability(provider, anime, episode_number, config, f
except Exception as e:
feedback.error(f"Error testing stream availability: {e}")
import traceback
feedback.error("Stream test traceback", traceback.format_exc())
return False
@@ -793,25 +932,31 @@ def _display_download_results(console, results: dict[str, bool], anime):
"""Display download results in a formatted table."""
from rich.table import Table
table = Table(title=f"Download Results for {anime.title.english or anime.title.romaji}")
table = Table(
title=f"Download Results for {anime.title.english or anime.title.romaji}"
)
table.add_column("Episode", justify="center", style="cyan")
table.add_column("Status", justify="center")
for episode, success in sorted(results.items(), key=lambda x: float(x[0])):
status = "[green]✓ Success[/green]" if success else "[red]✗ Failed[/red]"
table.add_row(episode, status)
console.print(table)
# Summary
total = len(results)
successful = sum(results.values())
failed = total - successful
if failed == 0:
console.print(f"\n[bold green]All {total} episodes downloaded successfully![/bold green]")
console.print(
f"\n[bold green]All {total} episodes downloaded successfully![/bold green]"
)
else:
console.print(f"\n[yellow]Download complete: {successful}/{total} successful, {failed} failed[/yellow]")
console.print(
f"\n[yellow]Download complete: {successful}/{total} successful, {failed} failed[/yellow]"
)
def _show_final_statistics(download_service, feedback):
@@ -820,17 +965,17 @@ def _show_final_statistics(download_service, feedback):
console = Console()
stats = download_service.get_download_statistics()
if stats:
console.print(f"\n[bold blue]Overall Download Statistics:[/bold blue]")
console.print("\n[bold blue]Overall Download Statistics:[/bold blue]")
console.print(f"Total episodes tracked: {stats.get('total_episodes', 0)}")
console.print(f"Successfully downloaded: {stats.get('downloaded', 0)}")
console.print(f"Failed downloads: {stats.get('failed', 0)}")
console.print(f"Queued downloads: {stats.get('queued', 0)}")
if stats.get('total_size_bytes', 0) > 0:
size_mb = stats['total_size_bytes'] / (1024 * 1024)
if stats.get("total_size_bytes", 0) > 0:
size_mb = stats["total_size_bytes"] / (1024 * 1024)
if size_mb > 1024:
console.print(f"Total size: {size_mb/1024:.2f} GB")
console.print(f"Total size: {size_mb / 1024:.2f} GB")
else:
console.print(f"Total size: {size_mb:.2f} MB")

View File

@@ -229,17 +229,39 @@ def search(config: AppConfig, **options: "Unpack[SearchOptions]"):
on_list = options.get("on_list")
# Validate logical relationships
if score_greater is not None and score_lesser is not None and score_greater > score_lesser:
if (
score_greater is not None
and score_lesser is not None
and score_greater > score_lesser
):
raise FastAnimeError("Minimum score cannot be higher than maximum score")
if popularity_greater is not None and popularity_lesser is not None and popularity_greater > popularity_lesser:
raise FastAnimeError("Minimum popularity cannot be higher than maximum popularity")
if start_date_greater is not None and start_date_lesser is not None and start_date_greater > start_date_lesser:
raise FastAnimeError("Start date greater cannot be later than start date lesser")
if end_date_greater is not None and end_date_lesser is not None and end_date_greater > end_date_lesser:
raise FastAnimeError("End date greater cannot be later than end date lesser")
if (
popularity_greater is not None
and popularity_lesser is not None
and popularity_greater > popularity_lesser
):
raise FastAnimeError(
"Minimum popularity cannot be higher than maximum popularity"
)
if (
start_date_greater is not None
and start_date_lesser is not None
and start_date_greater > start_date_lesser
):
raise FastAnimeError(
"Start date greater cannot be later than start date lesser"
)
if (
end_date_greater is not None
and end_date_lesser is not None
and end_date_greater > end_date_lesser
):
raise FastAnimeError(
"End date greater cannot be later than end date lesser"
)
# Build search parameters
search_params = MediaSearchParams(
@@ -287,7 +309,7 @@ def search(config: AppConfig, **options: "Unpack[SearchOptions]"):
feedback.info(
f"Found {len(search_result.media)} anime matching your search. Launching interactive mode..."
)
# Create initial state with search results
initial_state = State(
menu_name=MenuName.RESULTS,
@@ -299,7 +321,7 @@ def search(config: AppConfig, **options: "Unpack[SearchOptions]"):
page_info=search_result.page_info,
),
)
session.load_menus_from_folder("media")
session.run(config, history=[initial_state])

View File

@@ -16,7 +16,6 @@ def stats(config: "AppConfig"):
from rich.markdown import Markdown
from rich.panel import Panel
from .....core.exceptions import FastAnimeError
from .....libs.media_api.api import create_api_client
from ....service.auth import AuthService
from ....service.feedback import FeedbackService
@@ -93,4 +92,4 @@ def stats(config: "AppConfig"):
raise click.Abort()
except Exception as e:
feedback.error("Unexpected error occurred", str(e))
raise click.Abort()
raise click.Abort()

View File

@@ -150,20 +150,19 @@ def download(config: AppConfig, **options: "Unpack[Options]"):
if not anime:
raise FastAnimeError(f"Failed to fetch anime {anime_result.title}")
available_episodes: list[str] = sorted(
getattr(anime.episodes, config.stream.translation_type), key=float
)
if options["episode_range"]:
from ..utils.parser import parse_episode_range
try:
episodes_range = parse_episode_range(
options["episode_range"],
available_episodes
options["episode_range"], available_episodes
)
for episode in episodes_range:
download_anime(
config, options, provider, selector, anime, anime_title, episode

View File

@@ -2,7 +2,6 @@
Registry backup command - create full backups of the registry
"""
import shutil
import tarfile
from pathlib import Path
from datetime import datetime
@@ -19,31 +18,22 @@ from ....utils.feedback import create_feedback_manager
"--output",
"-o",
type=click.Path(),
help="Output backup file path (auto-generated if not specified)"
)
@click.option(
"--compress",
"-c",
is_flag=True,
help="Compress the backup archive"
)
@click.option(
"--include-cache",
is_flag=True,
help="Include cache files in backup"
help="Output backup file path (auto-generated if not specified)",
)
@click.option("--compress", "-c", is_flag=True, help="Compress the backup archive")
@click.option("--include-cache", is_flag=True, help="Include cache files in backup")
@click.option(
"--format",
"backup_format",
type=click.Choice(["tar", "zip"], case_sensitive=False),
default="tar",
help="Backup archive format"
help="Backup archive format",
)
@click.option(
"--api",
default="anilist",
type=click.Choice(["anilist"], case_sensitive=False),
help="Media API registry to backup"
help="Media API registry to backup",
)
@click.pass_obj
def backup(
@@ -52,35 +42,37 @@ def backup(
compress: bool,
include_cache: bool,
backup_format: str,
api: str
api: str,
):
"""
Create a complete backup of your media registry.
Includes all media records, index files, and optionally cache data.
Backups can be compressed and are suitable for restoration.
"""
feedback = create_feedback_manager(config.general.icons)
try:
registry_service = MediaRegistryService(api, config.registry)
# Generate output filename if not specified
if not output:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
extension = "tar.gz" if compress and backup_format == "tar" else backup_format
extension = (
"tar.gz" if compress and backup_format == "tar" else backup_format
)
if backup_format == "zip":
extension = "zip"
output = f"fastanime_registry_backup_{api}_{timestamp}.{extension}"
output_path = Path(output)
# Get backup statistics before starting
stats = registry_service.get_registry_stats()
total_media = stats.get('total_media', 0)
total_media = stats.get("total_media", 0)
feedback.info("Starting Backup", f"Backing up {total_media} media entries...")
# Create backup based on format
if backup_format.lower() == "tar":
_create_tar_backup(
@@ -90,101 +82,111 @@ def backup(
_create_zip_backup(
registry_service, output_path, include_cache, feedback, api
)
# Get final backup size
backup_size = _format_file_size(output_path)
feedback.success(
"Backup Complete",
f"Registry backed up to {output_path} ({backup_size})"
"Backup Complete", f"Registry backed up to {output_path} ({backup_size})"
)
# Show backup contents summary
_show_backup_summary(output_path, backup_format, feedback)
except Exception as e:
feedback.error("Backup Error", f"Failed to create backup: {e}")
raise click.Abort()
def _create_tar_backup(registry_service, output_path: Path, compress: bool, include_cache: bool, feedback, api: str):
def _create_tar_backup(
registry_service,
output_path: Path,
compress: bool,
include_cache: bool,
feedback,
api: str,
):
"""Create a tar-based backup."""
mode = "w:gz" if compress else "w"
with tarfile.open(output_path, mode) as tar:
# Add registry directory
registry_dir = registry_service.config.media_dir / api
if registry_dir.exists():
tar.add(registry_dir, arcname=f"registry/{api}")
feedback.info("Added to backup", f"Registry data ({api})")
# Add index directory
index_dir = registry_service.config.index_dir
if index_dir.exists():
tar.add(index_dir, arcname="index")
feedback.info("Added to backup", "Registry index")
# Add cache if requested
if include_cache:
cache_dir = registry_service.config.media_dir.parent / "cache"
if cache_dir.exists():
tar.add(cache_dir, arcname="cache")
feedback.info("Added to backup", "Cache data")
# Add metadata file
metadata = _create_backup_metadata(registry_service, api, include_cache)
metadata_path = output_path.parent / "backup_metadata.json"
try:
import json
with open(metadata_path, 'w', encoding='utf-8') as f:
with open(metadata_path, "w", encoding="utf-8") as f:
json.dump(metadata, f, indent=2, default=str)
tar.add(metadata_path, arcname="backup_metadata.json")
metadata_path.unlink() # Clean up temp file
except Exception as e:
feedback.warning("Metadata Error", f"Failed to add metadata: {e}")
def _create_zip_backup(registry_service, output_path: Path, include_cache: bool, feedback, api: str):
def _create_zip_backup(
registry_service, output_path: Path, include_cache: bool, feedback, api: str
):
"""Create a zip-based backup."""
import zipfile
with zipfile.ZipFile(output_path, 'w', zipfile.ZIP_DEFLATED) as zip_file:
with zipfile.ZipFile(output_path, "w", zipfile.ZIP_DEFLATED) as zip_file:
# Add registry directory
registry_dir = registry_service.config.media_dir / api
if registry_dir.exists():
for file_path in registry_dir.rglob('*'):
for file_path in registry_dir.rglob("*"):
if file_path.is_file():
arcname = f"registry/{api}/{file_path.relative_to(registry_dir)}"
zip_file.write(file_path, arcname)
feedback.info("Added to backup", f"Registry data ({api})")
# Add index directory
index_dir = registry_service.config.index_dir
if index_dir.exists():
for file_path in index_dir.rglob('*'):
for file_path in index_dir.rglob("*"):
if file_path.is_file():
arcname = f"index/{file_path.relative_to(index_dir)}"
zip_file.write(file_path, arcname)
feedback.info("Added to backup", "Registry index")
# Add cache if requested
if include_cache:
cache_dir = registry_service.config.media_dir.parent / "cache"
if cache_dir.exists():
for file_path in cache_dir.rglob('*'):
for file_path in cache_dir.rglob("*"):
if file_path.is_file():
arcname = f"cache/{file_path.relative_to(cache_dir)}"
zip_file.write(file_path, arcname)
feedback.info("Added to backup", "Cache data")
# Add metadata
metadata = _create_backup_metadata(registry_service, api, include_cache)
try:
import json
metadata_json = json.dumps(metadata, indent=2, default=str)
zip_file.writestr("backup_metadata.json", metadata_json)
except Exception as e:
@@ -194,13 +196,13 @@ def _create_zip_backup(registry_service, output_path: Path, include_cache: bool,
def _create_backup_metadata(registry_service, api: str, include_cache: bool) -> dict:
"""Create backup metadata."""
stats = registry_service.get_registry_stats()
return {
"backup_timestamp": datetime.now().isoformat(),
"fastanime_version": "unknown", # You might want to get this from somewhere
"registry_version": stats.get('version'),
"registry_version": stats.get("version"),
"api": api,
"total_media": stats.get('total_media', 0),
"total_media": stats.get("total_media", 0),
"include_cache": include_cache,
"registry_stats": stats,
"backup_type": "full",
@@ -209,22 +211,23 @@ def _create_backup_metadata(registry_service, api: str, include_cache: bool) ->
def _show_backup_summary(backup_path: Path, format_type: str, feedback):
"""Show summary of backup contents."""
try:
if format_type.lower() == "tar":
with tarfile.open(backup_path, 'r:*') as tar:
with tarfile.open(backup_path, "r:*") as tar:
members = tar.getmembers()
file_count = len([m for m in members if m.isfile()])
dir_count = len([m for m in members if m.isdir()])
else: # zip
import zipfile
with zipfile.ZipFile(backup_path, 'r') as zip_file:
with zipfile.ZipFile(backup_path, "r") as zip_file:
info_list = zip_file.infolist()
file_count = len([info for info in info_list if not info.is_dir()])
dir_count = len([info for info in info_list if info.is_dir()])
feedback.info("Backup Contents", f"{file_count} files, {dir_count} directories")
except Exception as e:
feedback.warning("Summary Error", f"Could not analyze backup contents: {e}")
@@ -233,7 +236,7 @@ def _format_file_size(file_path: Path) -> str:
"""Format file size in human-readable format."""
try:
size = file_path.stat().st_size
for unit in ['B', 'KB', 'MB', 'GB']:
for unit in ["B", "KB", "MB", "GB"]:
if size < 1024.0:
return f"{size:.1f} {unit}"
size /= 1024.0

View File

@@ -13,41 +13,26 @@ from ....utils.feedback import create_feedback_manager
@click.command(help="Clean up orphaned entries and invalid data from registry")
@click.option(
"--dry-run",
is_flag=True,
help="Show what would be cleaned without making changes"
"--dry-run", is_flag=True, help="Show what would be cleaned without making changes"
)
@click.option(
"--orphaned",
is_flag=True,
help="Remove orphaned media records (index entries without files)"
help="Remove orphaned media records (index entries without files)",
)
@click.option("--invalid", is_flag=True, help="Remove invalid or corrupted entries")
@click.option("--duplicates", is_flag=True, help="Remove duplicate entries")
@click.option(
"--old-format", is_flag=True, help="Clean entries from old registry format versions"
)
@click.option(
"--invalid",
is_flag=True,
help="Remove invalid or corrupted entries"
)
@click.option(
"--duplicates",
is_flag=True,
help="Remove duplicate entries"
)
@click.option(
"--old-format",
is_flag=True,
help="Clean entries from old registry format versions"
)
@click.option(
"--force",
"-f",
is_flag=True,
help="Force cleanup without confirmation prompts"
"--force", "-f", is_flag=True, help="Force cleanup without confirmation prompts"
)
@click.option(
"--api",
default="anilist",
type=click.Choice(["anilist"], case_sensitive=False),
help="Media API registry to clean"
help="Media API registry to clean",
)
@click.pass_obj
def clean(
@@ -58,73 +43,86 @@ def clean(
duplicates: bool,
old_format: bool,
force: bool,
api: str
api: str,
):
"""
Clean up your local media registry.
Can remove orphaned entries, invalid data, duplicates, and entries
from old format versions. Use --dry-run to preview changes.
"""
feedback = create_feedback_manager(config.general.icons)
console = Console()
# Default to all cleanup types if none specified
if not any([orphaned, invalid, duplicates, old_format]):
orphaned = invalid = duplicates = old_format = True
try:
registry_service = MediaRegistryService(api, config.registry)
cleanup_results = {
"orphaned": [],
"invalid": [],
"duplicates": [],
"old_format": []
"old_format": [],
}
# Analyze registry for cleanup opportunities
_analyze_registry(registry_service, cleanup_results, orphaned, invalid, duplicates, old_format)
_analyze_registry(
registry_service, cleanup_results, orphaned, invalid, duplicates, old_format
)
# Show cleanup summary
_display_cleanup_summary(console, cleanup_results, config.general.icons)
# Confirm cleanup if not dry run and not forced
total_items = sum(len(items) for items in cleanup_results.values())
if total_items == 0:
feedback.info("Registry Clean", "No cleanup needed - registry is already clean!")
feedback.info(
"Registry Clean", "No cleanup needed - registry is already clean!"
)
return
if not dry_run:
if not force:
if not click.confirm(f"Clean up {total_items} items from registry?"):
feedback.info("Cleanup Cancelled", "No changes were made")
return
# Perform cleanup
_perform_cleanup(registry_service, cleanup_results, feedback)
feedback.success("Cleanup Complete", f"Cleaned up {total_items} items from registry")
feedback.success(
"Cleanup Complete", f"Cleaned up {total_items} items from registry"
)
else:
feedback.info("Dry Run Complete", f"Would clean up {total_items} items")
except Exception as e:
feedback.error("Cleanup Error", f"Failed to clean registry: {e}")
raise click.Abort()
def _analyze_registry(registry_service, results: dict, check_orphaned: bool, check_invalid: bool, check_duplicates: bool, check_old_format: bool):
def _analyze_registry(
registry_service,
results: dict,
check_orphaned: bool,
check_invalid: bool,
check_duplicates: bool,
check_old_format: bool,
):
"""Analyze registry for cleanup opportunities."""
if check_orphaned:
results["orphaned"] = _find_orphaned_entries(registry_service)
if check_invalid:
results["invalid"] = _find_invalid_entries(registry_service)
if check_duplicates:
results["duplicates"] = _find_duplicate_entries(registry_service)
if check_old_format:
results["old_format"] = _find_old_format_entries(registry_service)
@@ -132,65 +130,77 @@ def _analyze_registry(registry_service, results: dict, check_orphaned: bool, che
def _find_orphaned_entries(registry_service) -> list:
"""Find index entries that don't have corresponding media files."""
orphaned = []
try:
index = registry_service._load_index()
for entry_key, entry in index.media_index.items():
media_file = registry_service._get_media_file_path(entry.media_id)
if not media_file.exists():
orphaned.append({
"type": "orphaned_index",
"id": entry.media_id,
"key": entry_key,
"reason": "Media file missing"
})
orphaned.append(
{
"type": "orphaned_index",
"id": entry.media_id,
"key": entry_key,
"reason": "Media file missing",
}
)
except Exception:
pass
return orphaned
def _find_invalid_entries(registry_service) -> list:
"""Find invalid or corrupted entries."""
invalid = []
try:
# Check all media files
for media_file in registry_service.media_registry_dir.iterdir():
if not media_file.name.endswith('.json'):
if not media_file.name.endswith(".json"):
continue
try:
media_id = int(media_file.stem)
record = registry_service.get_media_record(media_id)
# Check for invalid record structure
if not record or not record.media_item:
invalid.append({
"type": "invalid_record",
"id": media_id,
"file": media_file,
"reason": "Invalid record structure"
})
elif not record.media_item.title or not record.media_item.title.english and not record.media_item.title.romaji:
invalid.append({
"type": "invalid_title",
"id": media_id,
"file": media_file,
"reason": "Missing or invalid title"
})
invalid.append(
{
"type": "invalid_record",
"id": media_id,
"file": media_file,
"reason": "Invalid record structure",
}
)
elif (
not record.media_item.title
or not record.media_item.title.english
and not record.media_item.title.romaji
):
invalid.append(
{
"type": "invalid_title",
"id": media_id,
"file": media_file,
"reason": "Missing or invalid title",
}
)
except (ValueError, Exception) as e:
invalid.append({
"type": "corrupted_file",
"id": media_file.stem,
"file": media_file,
"reason": f"File corruption: {e}"
})
invalid.append(
{
"type": "corrupted_file",
"id": media_file.stem,
"file": media_file,
"reason": f"File corruption: {e}",
}
)
except Exception:
pass
return invalid
@@ -198,76 +208,81 @@ def _find_duplicate_entries(registry_service) -> list:
"""Find duplicate entries (same media ID appearing multiple times)."""
duplicates = []
seen_ids = set()
try:
index = registry_service._load_index()
for entry_key, entry in index.media_index.items():
if entry.media_id in seen_ids:
duplicates.append({
"type": "duplicate_index",
"id": entry.media_id,
"key": entry_key,
"reason": "Duplicate media ID in index"
})
duplicates.append(
{
"type": "duplicate_index",
"id": entry.media_id,
"key": entry_key,
"reason": "Duplicate media ID in index",
}
)
else:
seen_ids.add(entry.media_id)
except Exception:
pass
return duplicates
def _find_old_format_entries(registry_service) -> list:
"""Find entries from old registry format versions."""
old_format = []
try:
index = registry_service._load_index()
current_version = registry_service._index.version
# Check for entries that might be from old formats
# This is a placeholder - you'd implement specific checks based on your version history
for media_file in registry_service.media_registry_dir.iterdir():
if not media_file.name.endswith('.json'):
if not media_file.name.endswith(".json"):
continue
try:
import json
with open(media_file, 'r') as f:
with open(media_file, "r") as f:
data = json.load(f)
# Check for old format indicators
if 'version' in data and data['version'] < current_version:
old_format.append({
"type": "old_version",
"id": media_file.stem,
"file": media_file,
"reason": f"Old format version {data.get('version')}"
})
if "version" in data and data["version"] < current_version:
old_format.append(
{
"type": "old_version",
"id": media_file.stem,
"file": media_file,
"reason": f"Old format version {data.get('version')}",
}
)
except Exception:
pass
except Exception:
pass
return old_format
def _display_cleanup_summary(console: Console, results: dict, icons: bool):
"""Display summary of cleanup opportunities."""
table = Table(title=f"{'🧹 ' if icons else ''}Registry Cleanup Summary")
table.add_column("Category", style="cyan", no_wrap=True)
table.add_column("Count", style="magenta", justify="right")
table.add_column("Description", style="white")
categories = {
"orphaned": "Orphaned Entries",
"invalid": "Invalid Entries",
"invalid": "Invalid Entries",
"duplicates": "Duplicate Entries",
"old_format": "Old Format Entries"
"old_format": "Old Format Entries",
}
for category, display_name in categories.items():
count = len(results[category])
if count > 0:
@@ -278,52 +293,50 @@ def _display_cleanup_summary(console: Console, results: dict, icons: bool):
description += "..."
else:
description = "None found"
table.add_row(display_name, str(count), description)
console.print(table)
console.print()
# Show detailed breakdown if there are items to clean
for category, items in results.items():
if items:
_display_category_details(console, category, items, icons)
def _display_category_details(console: Console, category: str, items: list, icons: bool):
def _display_category_details(
console: Console, category: str, items: list, icons: bool
):
"""Display detailed breakdown for a cleanup category."""
category_names = {
"orphaned": "🔗 Orphaned Entries" if icons else "Orphaned Entries",
"invalid": "❌ Invalid Entries" if icons else "Invalid Entries",
"duplicates": "👥 Duplicate Entries" if icons else "Duplicate Entries",
"old_format": "📼 Old Format Entries" if icons else "Old Format Entries"
"duplicates": "👥 Duplicate Entries" if icons else "Duplicate Entries",
"old_format": "📼 Old Format Entries" if icons else "Old Format Entries",
}
table = Table(title=category_names.get(category, category.title()))
table.add_column("ID", style="cyan", no_wrap=True)
table.add_column("Type", style="magenta")
table.add_column("Reason", style="yellow")
for item in items[:10]: # Show max 10 items
table.add_row(
str(item["id"]),
item["type"],
item["reason"]
)
table.add_row(str(item["id"]), item["type"], item["reason"])
if len(items) > 10:
table.add_row("...", "...", f"And {len(items) - 10} more")
console.print(table)
console.print()
def _perform_cleanup(registry_service, results: dict, feedback):
"""Perform the actual cleanup operations."""
cleaned_count = 0
# Clean orphaned entries
for item in results["orphaned"]:
try:
@@ -334,25 +347,29 @@ def _perform_cleanup(registry_service, results: dict, feedback):
registry_service._save_index(index)
cleaned_count += 1
except Exception as e:
feedback.warning("Cleanup Error", f"Failed to clean orphaned entry {item['id']}: {e}")
feedback.warning(
"Cleanup Error", f"Failed to clean orphaned entry {item['id']}: {e}"
)
# Clean invalid entries
for item in results["invalid"]:
try:
if "file" in item:
item["file"].unlink() # Delete the file
cleaned_count += 1
# Also remove from index if present
index = registry_service._load_index()
entry_key = f"{registry_service._media_api}_{item['id']}"
if entry_key in index.media_index:
del index.media_index[entry_key]
registry_service._save_index(index)
except Exception as e:
feedback.warning("Cleanup Error", f"Failed to clean invalid entry {item['id']}: {e}")
feedback.warning(
"Cleanup Error", f"Failed to clean invalid entry {item['id']}: {e}"
)
# Clean duplicates
for item in results["duplicates"]:
try:
@@ -363,8 +380,10 @@ def _perform_cleanup(registry_service, results: dict, feedback):
registry_service._save_index(index)
cleaned_count += 1
except Exception as e:
feedback.warning("Cleanup Error", f"Failed to clean duplicate entry {item['id']}: {e}")
feedback.warning(
"Cleanup Error", f"Failed to clean duplicate entry {item['id']}: {e}"
)
# Clean old format entries
for item in results["old_format"]:
try:
@@ -374,6 +393,8 @@ def _perform_cleanup(registry_service, results: dict, feedback):
item["file"].unlink()
cleaned_count += 1
except Exception as e:
feedback.warning("Cleanup Error", f"Failed to clean old format entry {item['id']}: {e}")
feedback.warning(
"Cleanup Error", f"Failed to clean old format entry {item['id']}: {e}"
)
feedback.info("Cleanup Results", f"Successfully cleaned {cleaned_count} items")

View File

@@ -20,37 +20,32 @@ from ....utils.feedback import create_feedback_manager
"output_format",
type=click.Choice(["json", "csv", "xml"], case_sensitive=False),
default="json",
help="Export format"
help="Export format",
)
@click.option(
"--output",
"-o",
type=click.Path(),
help="Output file path (auto-generated if not specified)"
help="Output file path (auto-generated if not specified)",
)
@click.option(
"--include-metadata",
is_flag=True,
help="Include detailed media metadata in export"
"--include-metadata", is_flag=True, help="Include detailed media metadata in export"
)
@click.option(
"--status",
multiple=True,
type=click.Choice([
"watching", "completed", "planning", "dropped", "paused", "repeating"
], case_sensitive=False),
help="Only export specific status lists"
)
@click.option(
"--compress",
is_flag=True,
help="Compress the output file"
type=click.Choice(
["watching", "completed", "planning", "dropped", "paused", "repeating"],
case_sensitive=False,
),
help="Only export specific status lists",
)
@click.option("--compress", is_flag=True, help="Compress the output file")
@click.option(
"--api",
default="anilist",
type=click.Choice(["anilist"], case_sensitive=False),
help="Media API registry to export"
help="Media API registry to export",
)
@click.pass_obj
def export(
@@ -60,19 +55,19 @@ def export(
include_metadata: bool,
status: tuple[str, ...],
compress: bool,
api: str
api: str,
):
"""
Export your local media registry to various formats.
Supports JSON, CSV, and XML formats. Can optionally include
detailed metadata and compress the output.
"""
feedback = create_feedback_manager(config.general.icons)
try:
registry_service = MediaRegistryService(api, config.registry)
# Generate output filename if not specified
if not output:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -80,14 +75,12 @@ def export(
if compress:
extension += ".gz"
output = f"fastanime_registry_{api}_{timestamp}.{extension}"
output_path = Path(output)
# Get export data
export_data = _prepare_export_data(
registry_service, include_metadata, status
)
export_data = _prepare_export_data(registry_service, include_metadata, status)
# Export based on format
if output_format.lower() == "json":
_export_json(export_data, output_path, compress, feedback)
@@ -95,22 +88,25 @@ def export(
_export_csv(export_data, output_path, compress, feedback)
elif output_format.lower() == "xml":
_export_xml(export_data, output_path, compress, feedback)
feedback.success(
"Export Complete",
f"Registry exported to {output_path} ({_format_file_size(output_path)})"
f"Registry exported to {output_path} ({_format_file_size(output_path)})",
)
except Exception as e:
feedback.error("Export Error", f"Failed to export registry: {e}")
raise click.Abort()
def _prepare_export_data(registry_service, include_metadata: bool, status_filter: tuple[str, ...]) -> dict:
def _prepare_export_data(
registry_service, include_metadata: bool, status_filter: tuple[str, ...]
) -> dict:
"""Prepare data for export based on options."""
# Convert status filter to enums
from .....libs.media_api.types import UserMediaListStatus
status_map = {
"watching": UserMediaListStatus.WATCHING,
"completed": UserMediaListStatus.COMPLETED,
@@ -119,9 +115,9 @@ def _prepare_export_data(registry_service, include_metadata: bool, status_filter
"paused": UserMediaListStatus.PAUSED,
"repeating": UserMediaListStatus.REPEATING,
}
status_enums = [status_map[s] for s in status_filter] if status_filter else None
export_data = {
"metadata": {
"export_timestamp": datetime.now().isoformat(),
@@ -130,19 +126,19 @@ def _prepare_export_data(registry_service, include_metadata: bool, status_filter
"filtered_status": list(status_filter) if status_filter else None,
},
"statistics": registry_service.get_registry_stats(),
"media": []
"media": [],
}
# Get all records and filter by status if specified
all_records = registry_service.get_all_media_records()
for record in all_records:
index_entry = registry_service.get_media_index_entry(record.media_item.id)
# Skip if status filter is specified and doesn't match
if status_enums and (not index_entry or index_entry.status not in status_enums):
continue
media_data = {
"id": record.media_item.id,
"title": {
@@ -151,36 +147,63 @@ def _prepare_export_data(registry_service, include_metadata: bool, status_filter
"native": record.media_item.title.native,
},
"user_status": {
"status": index_entry.status.value if index_entry and index_entry.status else None,
"status": index_entry.status.value
if index_entry and index_entry.status
else None,
"progress": index_entry.progress if index_entry else None,
"score": index_entry.score if index_entry else None,
"last_watched": index_entry.last_watched.isoformat() if index_entry and index_entry.last_watched else None,
"last_watched": index_entry.last_watched.isoformat()
if index_entry and index_entry.last_watched
else None,
"notes": index_entry.notes if index_entry else None,
}
},
}
if include_metadata:
media_data.update({
"format": record.media_item.format.value if record.media_item.format else None,
"episodes": record.media_item.episodes,
"duration": record.media_item.duration,
"status": record.media_item.status.value if record.media_item.status else None,
"start_date": record.media_item.start_date.isoformat() if record.media_item.start_date else None,
"end_date": record.media_item.end_date.isoformat() if record.media_item.end_date else None,
"average_score": record.media_item.average_score,
"popularity": record.media_item.popularity,
"genres": [genre.value for genre in record.media_item.genres],
"tags": [{"name": tag.name.value, "rank": tag.rank} for tag in record.media_item.tags],
"studios": [studio.name for studio in record.media_item.studios if studio.name],
"description": record.media_item.description,
"cover_image": {
"large": record.media_item.cover_image.large if record.media_item.cover_image else None,
"medium": record.media_item.cover_image.medium if record.media_item.cover_image else None,
} if record.media_item.cover_image else None,
})
media_data.update(
{
"format": record.media_item.format.value
if record.media_item.format
else None,
"episodes": record.media_item.episodes,
"duration": record.media_item.duration,
"status": record.media_item.status.value
if record.media_item.status
else None,
"start_date": record.media_item.start_date.isoformat()
if record.media_item.start_date
else None,
"end_date": record.media_item.end_date.isoformat()
if record.media_item.end_date
else None,
"average_score": record.media_item.average_score,
"popularity": record.media_item.popularity,
"genres": [genre.value for genre in record.media_item.genres],
"tags": [
{"name": tag.name.value, "rank": tag.rank}
for tag in record.media_item.tags
],
"studios": [
studio.name
for studio in record.media_item.studios
if studio.name
],
"description": record.media_item.description,
"cover_image": {
"large": record.media_item.cover_image.large
if record.media_item.cover_image
else None,
"medium": record.media_item.cover_image.medium
if record.media_item.cover_image
else None,
}
if record.media_item.cover_image
else None,
}
)
export_data["media"].append(media_data)
return export_data
@@ -188,10 +211,11 @@ def _export_json(data: dict, output_path: Path, compress: bool, feedback):
"""Export data to JSON format."""
if compress:
import gzip
with gzip.open(output_path, 'wt', encoding='utf-8') as f:
with gzip.open(output_path, "wt", encoding="utf-8") as f:
json.dump(data, f, indent=2, ensure_ascii=False)
else:
with open(output_path, 'w', encoding='utf-8') as f:
with open(output_path, "w", encoding="utf-8") as f:
json.dump(data, f, indent=2, ensure_ascii=False)
@@ -199,21 +223,38 @@ def _export_csv(data: dict, output_path: Path, compress: bool, feedback):
"""Export data to CSV format."""
# Flatten media data for CSV
fieldnames = [
"id", "title_english", "title_romaji", "title_native",
"status", "progress", "score", "last_watched", "notes"
"id",
"title_english",
"title_romaji",
"title_native",
"status",
"progress",
"score",
"last_watched",
"notes",
]
# Add metadata fields if included
if data["metadata"]["include_metadata"]:
fieldnames.extend([
"format", "episodes", "duration", "media_status", "start_date", "end_date",
"average_score", "popularity", "genres", "description"
])
fieldnames.extend(
[
"format",
"episodes",
"duration",
"media_status",
"start_date",
"end_date",
"average_score",
"popularity",
"genres",
"description",
]
)
def write_csv(file_obj):
writer = csv.DictWriter(file_obj, fieldnames=fieldnames)
writer.writeheader()
for media in data["media"]:
row = {
"id": media["id"],
@@ -226,29 +267,32 @@ def _export_csv(data: dict, output_path: Path, compress: bool, feedback):
"last_watched": media["user_status"]["last_watched"],
"notes": media["user_status"]["notes"],
}
if data["metadata"]["include_metadata"]:
row.update({
"format": media.get("format"),
"episodes": media.get("episodes"),
"duration": media.get("duration"),
"media_status": media.get("status"),
"start_date": media.get("start_date"),
"end_date": media.get("end_date"),
"average_score": media.get("average_score"),
"popularity": media.get("popularity"),
"genres": ",".join(media.get("genres", [])),
"description": media.get("description"),
})
row.update(
{
"format": media.get("format"),
"episodes": media.get("episodes"),
"duration": media.get("duration"),
"media_status": media.get("status"),
"start_date": media.get("start_date"),
"end_date": media.get("end_date"),
"average_score": media.get("average_score"),
"popularity": media.get("popularity"),
"genres": ",".join(media.get("genres", [])),
"description": media.get("description"),
}
)
writer.writerow(row)
if compress:
import gzip
with gzip.open(output_path, 'wt', encoding='utf-8', newline='') as f:
with gzip.open(output_path, "wt", encoding="utf-8", newline="") as f:
write_csv(f)
else:
with open(output_path, 'w', encoding='utf-8', newline='') as f:
with open(output_path, "w", encoding="utf-8", newline="") as f:
write_csv(f)
@@ -259,43 +303,43 @@ def _export_xml(data: dict, output_path: Path, compress: bool, feedback):
except ImportError:
feedback.error("XML Export Error", "XML export requires Python's xml module")
raise click.Abort()
root = ET.Element("fastanime_registry")
# Add metadata
metadata_elem = ET.SubElement(root, "metadata")
for key, value in data["metadata"].items():
if value is not None:
elem = ET.SubElement(metadata_elem, key)
elem.text = str(value)
# Add statistics
stats_elem = ET.SubElement(root, "statistics")
for key, value in data["statistics"].items():
if value is not None:
elem = ET.SubElement(stats_elem, key)
elem.text = str(value)
# Add media
media_list_elem = ET.SubElement(root, "media_list")
for media in data["media"]:
media_elem = ET.SubElement(media_list_elem, "media")
media_elem.set("id", str(media["id"]))
# Add titles
titles_elem = ET.SubElement(media_elem, "titles")
for title_type, title_value in media["title"].items():
if title_value:
title_elem = ET.SubElement(titles_elem, title_type)
title_elem.text = title_value
# Add user status
status_elem = ET.SubElement(media_elem, "user_status")
for key, value in media["user_status"].items():
if value is not None:
elem = ET.SubElement(status_elem, key)
elem.text = str(value)
# Add metadata if included
if data["metadata"]["include_metadata"]:
for key, value in media.items():
@@ -314,22 +358,23 @@ def _export_xml(data: dict, output_path: Path, compress: bool, feedback):
else:
elem = ET.SubElement(media_elem, key)
elem.text = str(value)
# Write XML
tree = ET.ElementTree(root)
if compress:
import gzip
with gzip.open(output_path, 'wb') as f:
tree.write(f, encoding='utf-8', xml_declaration=True)
with gzip.open(output_path, "wb") as f:
tree.write(f, encoding="utf-8", xml_declaration=True)
else:
tree.write(output_path, encoding='utf-8', xml_declaration=True)
tree.write(output_path, encoding="utf-8", xml_declaration=True)
def _format_file_size(file_path: Path) -> str:
"""Format file size in human-readable format."""
try:
size = file_path.stat().st_size
for unit in ['B', 'KB', 'MB', 'GB']:
for unit in ["B", "KB", "MB", "GB"]:
if size < 1024.0:
return f"{size:.1f} {unit}"
size /= 1024.0

View File

@@ -22,34 +22,26 @@ from ....utils.feedback import create_feedback_manager
"input_format",
type=click.Choice(["json", "csv", "xml", "auto"], case_sensitive=False),
default="auto",
help="Input format (auto-detect if not specified)"
help="Input format (auto-detect if not specified)",
)
@click.option(
"--merge",
is_flag=True,
help="Merge with existing registry (default: replace)"
"--merge", is_flag=True, help="Merge with existing registry (default: replace)"
)
@click.option(
"--dry-run",
is_flag=True,
help="Show what would be imported without making changes"
"--dry-run", is_flag=True, help="Show what would be imported without making changes"
)
@click.option(
"--force",
"-f",
is_flag=True,
help="Force import even if format version doesn't match"
)
@click.option(
"--backup",
is_flag=True,
help="Create backup before importing"
help="Force import even if format version doesn't match",
)
@click.option("--backup", is_flag=True, help="Create backup before importing")
@click.option(
"--api",
default="anilist",
type=click.Choice(["anilist"], case_sensitive=False),
help="Media API registry to import to"
help="Media API registry to import to",
)
@click.pass_obj
def import_(
@@ -60,50 +52,50 @@ def import_(
dry_run: bool,
force: bool,
backup: bool,
api: str
api: str,
):
"""
Import media registry data from various formats.
Supports JSON, CSV, and XML formats exported by the export command
or compatible third-party tools.
"""
feedback = create_feedback_manager(config.general.icons)
try:
registry_service = MediaRegistryService(api, config.registry)
# Create backup if requested
if backup and not dry_run:
_create_backup(registry_service, feedback)
# Auto-detect format if needed
if input_format == "auto":
input_format = _detect_format(input_file)
feedback.info("Format Detection", f"Detected format: {input_format.upper()}")
feedback.info(
"Format Detection", f"Detected format: {input_format.upper()}"
)
# Parse input file
import_data = _parse_input_file(input_file, input_format, feedback)
# Validate import data
_validate_import_data(import_data, force, feedback)
# Import data
_import_data(
registry_service, import_data, merge, dry_run, feedback
)
_import_data(registry_service, import_data, merge, dry_run, feedback)
if not dry_run:
feedback.success(
"Import Complete",
f"Successfully imported {len(import_data.get('media', []))} media entries"
f"Successfully imported {len(import_data.get('media', []))} media entries",
)
else:
feedback.info(
"Dry Run Complete",
f"Would import {len(import_data.get('media', []))} media entries"
f"Would import {len(import_data.get('media', []))} media entries",
)
except Exception as e:
feedback.error("Import Error", f"Failed to import registry: {e}")
raise click.Abort()
@@ -112,40 +104,40 @@ def import_(
def _create_backup(registry_service, feedback):
"""Create a backup before importing."""
from .export import _prepare_export_data, _export_json
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_path = Path(f"fastanime_registry_backup_{timestamp}.json")
export_data = _prepare_export_data(registry_service, True, ())
_export_json(export_data, backup_path, False, feedback)
feedback.info("Backup Created", f"Registry backed up to {backup_path}")
def _detect_format(file_path: Path) -> str:
"""Auto-detect file format based on extension and content."""
extension = file_path.suffix.lower()
if extension in ['.json', '.gz']:
if extension in [".json", ".gz"]:
return "json"
elif extension == '.csv':
elif extension == ".csv":
return "csv"
elif extension == '.xml':
elif extension == ".xml":
return "xml"
# Try to detect by content
try:
with open(file_path, 'r', encoding='utf-8') as f:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read(100).strip()
if content.startswith('{') or content.startswith('['):
if content.startswith("{") or content.startswith("["):
return "json"
elif content.startswith('<?xml') or content.startswith('<'):
elif content.startswith("<?xml") or content.startswith("<"):
return "xml"
elif ',' in content: # Very basic CSV detection
elif "," in content: # Very basic CSV detection
return "csv"
except:
pass
raise click.ClickException(f"Could not detect format for {file_path}")
@@ -164,12 +156,13 @@ def _parse_input_file(file_path: Path, format_type: str, feedback) -> dict:
def _parse_json(file_path: Path) -> dict:
"""Parse JSON input file."""
try:
if file_path.suffix.lower() == '.gz':
if file_path.suffix.lower() == ".gz":
import gzip
with gzip.open(file_path, 'rt', encoding='utf-8') as f:
with gzip.open(file_path, "rt", encoding="utf-8") as f:
return json.load(f)
else:
with open(file_path, 'r', encoding='utf-8') as f:
with open(file_path, "r", encoding="utf-8") as f:
return json.load(f)
except json.JSONDecodeError as e:
raise click.ClickException(f"Invalid JSON format: {e}")
@@ -182,11 +175,11 @@ def _parse_csv(file_path: Path) -> dict:
"import_timestamp": datetime.now().isoformat(),
"source_format": "csv",
},
"media": []
"media": [],
}
try:
with open(file_path, 'r', encoding='utf-8') as f:
with open(file_path, "r", encoding="utf-8") as f:
reader = csv.DictReader(f)
for row in reader:
media_data = {
@@ -198,33 +191,47 @@ def _parse_csv(file_path: Path) -> dict:
},
"user_status": {
"status": row.get("status"),
"progress": int(row["progress"]) if row.get("progress") else None,
"progress": int(row["progress"])
if row.get("progress")
else None,
"score": float(row["score"]) if row.get("score") else None,
"last_watched": row.get("last_watched"),
"notes": row.get("notes"),
}
},
}
# Add metadata fields if present
if "format" in row:
media_data.update({
"format": row.get("format"),
"episodes": int(row["episodes"]) if row.get("episodes") else None,
"duration": int(row["duration"]) if row.get("duration") else None,
"media_status": row.get("media_status"),
"start_date": row.get("start_date"),
"end_date": row.get("end_date"),
"average_score": float(row["average_score"]) if row.get("average_score") else None,
"popularity": int(row["popularity"]) if row.get("popularity") else None,
"genres": row.get("genres", "").split(",") if row.get("genres") else [],
"description": row.get("description"),
})
media_data.update(
{
"format": row.get("format"),
"episodes": int(row["episodes"])
if row.get("episodes")
else None,
"duration": int(row["duration"])
if row.get("duration")
else None,
"media_status": row.get("media_status"),
"start_date": row.get("start_date"),
"end_date": row.get("end_date"),
"average_score": float(row["average_score"])
if row.get("average_score")
else None,
"popularity": int(row["popularity"])
if row.get("popularity")
else None,
"genres": row.get("genres", "").split(",")
if row.get("genres")
else [],
"description": row.get("description"),
}
)
import_data["media"].append(media_data)
except (ValueError, KeyError) as e:
raise click.ClickException(f"Invalid CSV format: {e}")
return import_data
@@ -234,22 +241,19 @@ def _parse_xml(file_path: Path) -> dict:
import xml.etree.ElementTree as ET
except ImportError:
raise click.ClickException("XML import requires Python's xml module")
try:
tree = ET.parse(file_path)
root = tree.getroot()
import_data = {
"metadata": {},
"media": []
}
import_data = {"metadata": {}, "media": []}
# Parse metadata
metadata_elem = root.find("metadata")
if metadata_elem is not None:
for child in metadata_elem:
import_data["metadata"][child.tag] = child.text
# Parse media
media_list_elem = root.find("media_list")
if media_list_elem is not None:
@@ -257,15 +261,15 @@ def _parse_xml(file_path: Path) -> dict:
media_data = {
"id": int(media_elem.get("id")),
"title": {},
"user_status": {}
"user_status": {},
}
# Parse titles
titles_elem = media_elem.find("titles")
if titles_elem is not None:
for title_elem in titles_elem:
media_data["title"][title_elem.tag] = title_elem.text
# Parse user status
status_elem = media_elem.find("user_status")
if status_elem is not None:
@@ -273,32 +277,38 @@ def _parse_xml(file_path: Path) -> dict:
value = child.text
if child.tag in ["progress", "score"] and value:
try:
value = float(value) if child.tag == "score" else int(value)
value = (
float(value) if child.tag == "score" else int(value)
)
except ValueError:
pass
media_data["user_status"][child.tag] = value
# Parse other metadata
for child in media_elem:
if child.tag not in ["titles", "user_status"]:
if child.tag in ["episodes", "duration", "popularity"]:
try:
media_data[child.tag] = int(child.text) if child.text else None
media_data[child.tag] = (
int(child.text) if child.text else None
)
except ValueError:
media_data[child.tag] = child.text
elif child.tag == "average_score":
try:
media_data[child.tag] = float(child.text) if child.text else None
media_data[child.tag] = (
float(child.text) if child.text else None
)
except ValueError:
media_data[child.tag] = child.text
else:
media_data[child.tag] = child.text
import_data["media"].append(media_data)
except ET.ParseError as e:
raise click.ClickException(f"Invalid XML format: {e}")
return import_data
@@ -306,36 +316,43 @@ def _validate_import_data(data: dict, force: bool, feedback):
"""Validate import data structure and compatibility."""
if "media" not in data:
raise click.ClickException("Import data missing 'media' section")
if not isinstance(data["media"], list):
raise click.ClickException("'media' section must be a list")
# Check if any media entries exist
if not data["media"]:
feedback.warning("No Media", "Import file contains no media entries")
return
# Validate media entries
required_fields = ["id", "title"]
for i, media in enumerate(data["media"]):
for field in required_fields:
if field not in media:
raise click.ClickException(f"Media entry {i} missing required field: {field}")
raise click.ClickException(
f"Media entry {i} missing required field: {field}"
)
if not isinstance(media.get("title"), dict):
raise click.ClickException(f"Media entry {i} has invalid title format")
feedback.info("Validation", f"Import data validated - {len(data['media'])} media entries")
feedback.info(
"Validation", f"Import data validated - {len(data['media'])} media entries"
)
def _import_data(registry_service, data: dict, merge: bool, dry_run: bool, feedback):
"""Import data into the registry."""
from .....libs.media_api.types import MediaFormat, MediaGenre, MediaStatus, MediaType
from .....libs.media_api.types import (
MediaFormat,
MediaType,
)
imported_count = 0
updated_count = 0
error_count = 0
status_map = {
"watching": UserMediaListStatus.WATCHING,
"completed": UserMediaListStatus.COMPLETED,
@@ -344,47 +361,47 @@ def _import_data(registry_service, data: dict, merge: bool, dry_run: bool, feedb
"paused": UserMediaListStatus.PAUSED,
"repeating": UserMediaListStatus.REPEATING,
}
for media_data in data["media"]:
try:
media_id = media_data["id"]
if not media_id:
error_count += 1
continue
title_data = media_data.get("title", {})
title = MediaTitle(
english=title_data.get("english") or "",
romaji=title_data.get("romaji"),
native=title_data.get("native"),
)
# Create minimal MediaItem for registry
media_item = MediaItem(
id=media_id,
title=title,
type=MediaType.ANIME, # Default to anime
)
# Add additional metadata if available
if "format" in media_data and media_data["format"]:
try:
media_item.format = getattr(MediaFormat, media_data["format"])
except (AttributeError, TypeError):
pass
if "episodes" in media_data:
media_item.episodes = media_data["episodes"]
if "average_score" in media_data:
media_item.average_score = media_data["average_score"]
if dry_run:
title_str = title.english or title.romaji or f"ID:{media_id}"
feedback.info("Would import", title_str)
imported_count += 1
continue
# Check if record exists
existing_record = registry_service.get_media_record(media_id)
if existing_record and not merge:
@@ -394,11 +411,11 @@ def _import_data(registry_service, data: dict, merge: bool, dry_run: bool, feedb
updated_count += 1
else:
imported_count += 1
# Create or update record
record = registry_service.get_or_create_record(media_item)
registry_service.save_media_record(record)
# Update user status if provided
user_status = media_data.get("user_status", {})
if user_status.get("status"):
@@ -412,14 +429,17 @@ def _import_data(registry_service, data: dict, merge: bool, dry_run: bool, feedb
score=user_status.get("score"),
notes=user_status.get("notes"),
)
except Exception as e:
error_count += 1
feedback.warning("Import Error", f"Failed to import media {media_data.get('id', 'unknown')}: {e}")
feedback.warning(
"Import Error",
f"Failed to import media {media_data.get('id', 'unknown')}: {e}",
)
continue
if not dry_run:
feedback.info(
"Import Summary",
f"Imported: {imported_count}, Updated: {updated_count}, Errors: {error_count}"
f"Imported: {imported_count}, Updated: {updated_count}, Errors: {error_count}",
)

View File

@@ -17,26 +17,19 @@ from ....utils.feedback import create_feedback_manager
@click.command(help="Restore registry from a backup file")
@click.argument("backup_file", type=click.Path(exists=True, path_type=Path))
@click.option(
"--force",
"-f",
is_flag=True,
help="Force restore even if current registry exists"
"--force", "-f", is_flag=True, help="Force restore even if current registry exists"
)
@click.option(
"--backup-current",
is_flag=True,
help="Create backup of current registry before restoring"
)
@click.option(
"--verify",
is_flag=True,
help="Verify backup integrity before restoring"
help="Create backup of current registry before restoring",
)
@click.option("--verify", is_flag=True, help="Verify backup integrity before restoring")
@click.option(
"--api",
default="anilist",
type=click.Choice(["anilist"], case_sensitive=False),
help="Media API registry to restore to"
help="Media API registry to restore to",
)
@click.pass_obj
def restore(
@@ -45,57 +38,66 @@ def restore(
force: bool,
backup_current: bool,
verify: bool,
api: str
api: str,
):
"""
Restore your media registry from a backup file.
Can restore from tar or zip backups created by the backup command.
Optionally creates a backup of the current registry before restoring.
"""
feedback = create_feedback_manager(config.general.icons)
try:
# Detect backup format
backup_format = _detect_backup_format(backup_file)
feedback.info("Backup Format", f"Detected {backup_format.upper()} format")
# Verify backup if requested
if verify:
if not _verify_backup(backup_file, backup_format, feedback):
feedback.error("Verification Failed", "Backup file appears to be corrupted")
feedback.error(
"Verification Failed", "Backup file appears to be corrupted"
)
raise click.Abort()
feedback.success("Verification", "Backup file integrity verified")
# Check if current registry exists
registry_service = MediaRegistryService(api, config.registry)
registry_exists = _check_registry_exists(registry_service)
if registry_exists and not force:
if not click.confirm("Current registry exists. Continue with restore?"):
feedback.info("Restore Cancelled", "No changes were made")
return
# Create backup of current registry if requested
if backup_current and registry_exists:
_backup_current_registry(registry_service, api, feedback)
# Show restore summary
_show_restore_summary(backup_file, backup_format, feedback)
# Perform restore
_perform_restore(backup_file, backup_format, config, api, feedback)
feedback.success("Restore Complete", "Registry has been successfully restored from backup")
feedback.success(
"Restore Complete", "Registry has been successfully restored from backup"
)
# Verify restored registry
try:
restored_service = MediaRegistryService(api, config.registry)
stats = restored_service.get_registry_stats()
feedback.info("Restored Registry", f"Contains {stats.get('total_media', 0)} media entries")
feedback.info(
"Restored Registry",
f"Contains {stats.get('total_media', 0)} media entries",
)
except Exception as e:
feedback.warning("Verification Warning", f"Could not verify restored registry: {e}")
feedback.warning(
"Verification Warning", f"Could not verify restored registry: {e}"
)
except Exception as e:
feedback.error("Restore Error", f"Failed to restore registry: {e}")
raise click.Abort()
@@ -103,27 +105,28 @@ def restore(
def _detect_backup_format(backup_file: Path) -> str:
"""Detect backup file format."""
if backup_file.suffix.lower() in ['.tar', '.gz']:
if backup_file.suffix.lower() in [".tar", ".gz"]:
return "tar"
elif backup_file.suffix.lower() == '.zip':
elif backup_file.suffix.lower() == ".zip":
return "zip"
elif backup_file.name.endswith('.tar.gz'):
elif backup_file.name.endswith(".tar.gz"):
return "tar"
else:
# Try to detect by content
try:
with tarfile.open(backup_file, 'r:*'):
with tarfile.open(backup_file, "r:*"):
return "tar"
except:
pass
try:
import zipfile
with zipfile.ZipFile(backup_file, 'r'):
with zipfile.ZipFile(backup_file, "r"):
return "zip"
except:
pass
raise click.ClickException(f"Could not detect backup format for {backup_file}")
@@ -131,53 +134,68 @@ def _verify_backup(backup_file: Path, format_type: str, feedback) -> bool:
"""Verify backup file integrity."""
try:
if format_type == "tar":
with tarfile.open(backup_file, 'r:*') as tar:
with tarfile.open(backup_file, "r:*") as tar:
# Check if essential files exist
names = tar.getnames()
has_registry = any('registry/' in name for name in names)
has_index = any('index/' in name for name in names)
has_metadata = 'backup_metadata.json' in names
has_registry = any("registry/" in name for name in names)
has_index = any("index/" in name for name in names)
has_metadata = "backup_metadata.json" in names
if not (has_registry and has_index):
return False
# Try to read metadata if it exists
if has_metadata:
try:
metadata_member = tar.getmember('backup_metadata.json')
metadata_member = tar.getmember("backup_metadata.json")
metadata_file = tar.extractfile(metadata_member)
if metadata_file:
import json
metadata = json.load(metadata_file)
feedback.info("Backup Info", f"Created: {metadata.get('backup_timestamp', 'Unknown')}")
feedback.info("Backup Info", f"Total Media: {metadata.get('total_media', 'Unknown')}")
feedback.info(
"Backup Info",
f"Created: {metadata.get('backup_timestamp', 'Unknown')}",
)
feedback.info(
"Backup Info",
f"Total Media: {metadata.get('total_media', 'Unknown')}",
)
except:
pass
else: # zip
import zipfile
with zipfile.ZipFile(backup_file, 'r') as zip_file:
with zipfile.ZipFile(backup_file, "r") as zip_file:
names = zip_file.namelist()
has_registry = any('registry/' in name for name in names)
has_index = any('index/' in name for name in names)
has_metadata = 'backup_metadata.json' in names
has_registry = any("registry/" in name for name in names)
has_index = any("index/" in name for name in names)
has_metadata = "backup_metadata.json" in names
if not (has_registry and has_index):
return False
# Try to read metadata
if has_metadata:
try:
with zip_file.open('backup_metadata.json') as metadata_file:
with zip_file.open("backup_metadata.json") as metadata_file:
import json
metadata = json.load(metadata_file)
feedback.info("Backup Info", f"Created: {metadata.get('backup_timestamp', 'Unknown')}")
feedback.info("Backup Info", f"Total Media: {metadata.get('total_media', 'Unknown')}")
feedback.info(
"Backup Info",
f"Created: {metadata.get('backup_timestamp', 'Unknown')}",
)
feedback.info(
"Backup Info",
f"Total Media: {metadata.get('total_media', 'Unknown')}",
)
except:
pass
return True
except Exception:
return False
@@ -186,7 +204,7 @@ def _check_registry_exists(registry_service) -> bool:
"""Check if a registry already exists."""
try:
stats = registry_service.get_registry_stats()
return stats.get('total_media', 0) > 0
return stats.get("total_media", 0) > 0
except:
return False
@@ -194,10 +212,10 @@ def _check_registry_exists(registry_service) -> bool:
def _backup_current_registry(registry_service, api: str, feedback):
"""Create backup of current registry before restoring."""
from .backup import _create_tar_backup
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_path = Path(f"fastanime_registry_pre_restore_{api}_{timestamp}.tar.gz")
try:
_create_tar_backup(registry_service, backup_path, True, False, feedback, api)
feedback.info("Current Registry Backed Up", f"Saved to {backup_path}")
@@ -209,72 +227,89 @@ def _show_restore_summary(backup_file: Path, format_type: str, feedback):
"""Show summary of what will be restored."""
try:
if format_type == "tar":
with tarfile.open(backup_file, 'r:*') as tar:
with tarfile.open(backup_file, "r:*") as tar:
members = tar.getmembers()
file_count = len([m for m in members if m.isfile()])
# Count media files
media_files = len([m for m in members if m.name.startswith('registry/') and m.name.endswith('.json')])
media_files = len(
[
m
for m in members
if m.name.startswith("registry/") and m.name.endswith(".json")
]
)
else: # zip
import zipfile
with zipfile.ZipFile(backup_file, 'r') as zip_file:
with zipfile.ZipFile(backup_file, "r") as zip_file:
info_list = zip_file.infolist()
file_count = len([info for info in info_list if not info.is_dir()])
# Count media files
media_files = len([info for info in info_list if info.filename.startswith('registry/') and info.filename.endswith('.json')])
media_files = len(
[
info
for info in info_list
if info.filename.startswith("registry/")
and info.filename.endswith(".json")
]
)
feedback.info("Restore Preview", f"Will restore {file_count} files")
feedback.info("Media Records", f"Contains {media_files} media entries")
except Exception as e:
feedback.warning("Preview Error", f"Could not analyze backup: {e}")
def _perform_restore(backup_file: Path, format_type: str, config: AppConfig, api: str, feedback):
def _perform_restore(
backup_file: Path, format_type: str, config: AppConfig, api: str, feedback
):
"""Perform the actual restore operation."""
# Create temporary extraction directory
temp_dir = Path(config.registry.media_dir.parent / "restore_temp")
temp_dir.mkdir(exist_ok=True)
try:
# Extract backup
if format_type == "tar":
with tarfile.open(backup_file, 'r:*') as tar:
with tarfile.open(backup_file, "r:*") as tar:
tar.extractall(temp_dir)
else: # zip
import zipfile
with zipfile.ZipFile(backup_file, 'r') as zip_file:
with zipfile.ZipFile(backup_file, "r") as zip_file:
zip_file.extractall(temp_dir)
feedback.info("Extraction", "Backup extracted to temporary directory")
# Remove existing registry if it exists
registry_dir = config.registry.media_dir / api
index_dir = config.registry.index_dir
if registry_dir.exists():
shutil.rmtree(registry_dir)
feedback.info("Cleanup", "Removed existing registry data")
if index_dir.exists():
shutil.rmtree(index_dir)
feedback.info("Cleanup", "Removed existing index data")
# Move extracted files to proper locations
extracted_registry = temp_dir / "registry" / api
extracted_index = temp_dir / "index"
if extracted_registry.exists():
shutil.move(str(extracted_registry), str(registry_dir))
feedback.info("Restore", "Registry data restored")
if extracted_index.exists():
shutil.move(str(extracted_index), str(index_dir))
feedback.info("Restore", "Index data restored")
# Restore cache if it exists
extracted_cache = temp_dir / "cache"
if extracted_cache.exists():
@@ -283,7 +318,7 @@ def _perform_restore(backup_file: Path, format_type: str, config: AppConfig, api
shutil.rmtree(cache_dir)
shutil.move(str(extracted_cache), str(cache_dir))
feedback.info("Restore", "Cache data restored")
finally:
# Clean up temporary directory
if temp_dir.exists():

View File

@@ -17,63 +17,44 @@ from ....utils.feedback import create_feedback_manager
@click.argument("query", required=False)
@click.option(
"--status",
type=click.Choice([
"watching", "completed", "planning", "dropped", "paused", "repeating"
], case_sensitive=False),
help="Filter by watch status"
type=click.Choice(
["watching", "completed", "planning", "dropped", "paused", "repeating"],
case_sensitive=False,
),
help="Filter by watch status",
)
@click.option(
"--genre",
multiple=True,
help="Filter by genre (can be used multiple times)"
"--genre", multiple=True, help="Filter by genre (can be used multiple times)"
)
@click.option(
"--format",
type=click.Choice([
"TV", "TV_SHORT", "MOVIE", "SPECIAL", "OVA", "ONA", "MUSIC"
], case_sensitive=False),
help="Filter by format"
)
@click.option(
"--year",
type=int,
help="Filter by release year"
)
@click.option(
"--min-score",
type=float,
help="Minimum average score (0.0 - 10.0)"
)
@click.option(
"--max-score",
type=float,
help="Maximum average score (0.0 - 10.0)"
type=click.Choice(
["TV", "TV_SHORT", "MOVIE", "SPECIAL", "OVA", "ONA", "MUSIC"],
case_sensitive=False,
),
help="Filter by format",
)
@click.option("--year", type=int, help="Filter by release year")
@click.option("--min-score", type=float, help="Minimum average score (0.0 - 10.0)")
@click.option("--max-score", type=float, help="Maximum average score (0.0 - 10.0)")
@click.option(
"--sort",
type=click.Choice([
"title", "score", "popularity", "year", "episodes", "updated"
], case_sensitive=False),
type=click.Choice(
["title", "score", "popularity", "year", "episodes", "updated"],
case_sensitive=False,
),
default="title",
help="Sort results by field"
help="Sort results by field",
)
@click.option("--limit", type=int, default=20, help="Maximum number of results to show")
@click.option(
"--limit",
type=int,
default=20,
help="Maximum number of results to show"
)
@click.option(
"--json",
"output_json",
is_flag=True,
help="Output results in JSON format"
"--json", "output_json", is_flag=True, help="Output results in JSON format"
)
@click.option(
"--api",
default="anilist",
type=click.Choice(["anilist"], case_sensitive=False),
help="Media API registry to search"
help="Media API registry to search",
)
@click.pass_obj
def search(
@@ -88,39 +69,40 @@ def search(
sort: str,
limit: int,
output_json: bool,
api: str
api: str,
):
"""
Search through your local media registry.
You can search by title and filter by various criteria like status,
genre, format, year, and score range.
"""
feedback = create_feedback_manager(config.general.icons)
console = Console()
try:
registry_service = MediaRegistryService(api, config.registry)
# Build search parameters
search_params = _build_search_params(
query, status, genre, format, year, min_score, max_score, sort, limit
)
# Perform search
result = registry_service.search_for_media(search_params)
if not result or not result.media:
feedback.info("No Results", "No media found matching your criteria")
return
if output_json:
import json
print(json.dumps(result.model_dump(), indent=2, default=str))
return
_display_search_results(console, result, config.general.icons)
except Exception as e:
feedback.error("Search Error", f"Failed to search registry: {e}")
raise click.Abort()
@@ -130,20 +112,20 @@ def _build_search_params(
query, status, genre, format, year, min_score, max_score, sort, limit
) -> MediaSearchParams:
"""Build MediaSearchParams from command options."""
# Convert status string to enum
status_enum = None
if status:
status_map = {
"watching": UserMediaListStatus.WATCHING,
"completed": UserMediaListStatus.COMPLETED,
"completed": UserMediaListStatus.COMPLETED,
"planning": UserMediaListStatus.PLANNING,
"dropped": UserMediaListStatus.DROPPED,
"paused": UserMediaListStatus.PAUSED,
"repeating": UserMediaListStatus.REPEATING,
}
status_enum = status_map.get(status.lower())
# Convert sort string to enum
sort_map = {
"title": MediaSort.TITLE_ROMAJI,
@@ -154,29 +136,33 @@ def _build_search_params(
"updated": MediaSort.UPDATED_AT_DESC,
}
sort_enum = sort_map.get(sort.lower(), MediaSort.TITLE_ROMAJI)
# Convert format string to enum if provided
format_enum = None
if format:
from .....libs.media_api.types import MediaFormat
format_enum = getattr(MediaFormat, format.upper(), None)
# Convert genre strings to enums
genre_enums = []
if genre:
from .....libs.media_api.types import MediaGenre
for g in genre:
# Try to find matching genre enum
for genre_enum in MediaGenre:
if genre_enum.value.lower() == g.lower():
genre_enums.append(genre_enum)
break
return MediaSearchParams(
query=query,
per_page=limit,
sort=[sort_enum],
averageScore_greater=min_score * 10 if min_score else None, # Convert to AniList scale
averageScore_greater=min_score * 10
if min_score
else None, # Convert to AniList scale
averageScore_lesser=max_score * 10 if max_score else None,
genre_in=genre_enums if genre_enums else None,
format_in=[format_enum] if format_enum else None,
@@ -187,8 +173,10 @@ def _build_search_params(
def _display_search_results(console: Console, result, icons: bool):
"""Display search results in a formatted table."""
table = Table(title=f"{'🔍 ' if icons else ''}Search Results ({len(result.media)} found)")
table = Table(
title=f"{'🔍 ' if icons else ''}Search Results ({len(result.media)} found)"
)
table.add_column("Title", style="cyan", min_width=30)
table.add_column("Year", style="dim", justify="center", min_width=6)
table.add_column("Format", style="magenta", justify="center", min_width=8)
@@ -196,31 +184,35 @@ def _display_search_results(console: Console, result, icons: bool):
table.add_column("Score", style="yellow", justify="center", min_width=6)
table.add_column("Status", style="blue", justify="center", min_width=10)
table.add_column("Progress", style="white", justify="center", min_width=8)
for media in result.media:
# Get title (prefer English, fallback to Romaji)
title = media.title.english or media.title.romaji or "Unknown"
if len(title) > 40:
title = title[:37] + "..."
# Get year from start date
year = ""
if media.start_date:
year = str(media.start_date.year)
# Format episodes
episodes = str(media.episodes) if media.episodes else "?"
# Format score
score = f"{media.average_score/10:.1f}" if media.average_score else "N/A"
score = f"{media.average_score / 10:.1f}" if media.average_score else "N/A"
# Get user status
status = "Not Listed"
progress = "0"
if media.user_status:
status = media.user_status.status.value.title() if media.user_status.status else "Unknown"
status = (
media.user_status.status.value.title()
if media.user_status.status
else "Unknown"
)
progress = f"{media.user_status.progress or 0}/{episodes}"
table.add_row(
title,
year,
@@ -228,11 +220,11 @@ def _display_search_results(console: Console, result, icons: bool):
episodes,
score,
status,
progress
progress,
)
console.print(table)
# Show pagination info if applicable
if result.page_info.total > len(result.media):
console.print(

View File

@@ -17,45 +17,43 @@ from ....utils.feedback import create_feedback_manager
"--detailed",
"-d",
is_flag=True,
help="Show detailed breakdown by genre, format, and year"
help="Show detailed breakdown by genre, format, and year",
)
@click.option(
"--json",
"output_json",
is_flag=True,
help="Output statistics in JSON format"
"--json", "output_json", is_flag=True, help="Output statistics in JSON format"
)
@click.option(
"--api",
default="anilist",
type=click.Choice(["anilist"], case_sensitive=False),
help="Media API to show stats for"
help="Media API to show stats for",
)
@click.pass_obj
def stats(config: AppConfig, detailed: bool, output_json: bool, api: str):
"""
Display comprehensive statistics about your local media registry.
Shows total counts, status breakdown, and optionally detailed
Shows total counts, status breakdown, and optionally detailed
analysis by genre, format, and release year.
"""
feedback = create_feedback_manager(config.general.icons)
console = Console()
try:
registry_service = MediaRegistryService(api, config.registry)
stats_data = registry_service.get_registry_stats()
if output_json:
import json
print(json.dumps(stats_data, indent=2, default=str))
return
_display_stats_overview(console, stats_data, api, config.general.icons)
if detailed:
_display_detailed_stats(console, stats_data, config.general.icons)
except Exception as e:
feedback.error("Stats Error", f"Failed to generate statistics: {e}")
raise click.Abort()
@@ -63,118 +61,122 @@ def stats(config: AppConfig, detailed: bool, output_json: bool, api: str):
def _display_stats_overview(console: Console, stats: dict, api: str, icons: bool):
"""Display basic registry statistics overview."""
# Main overview panel
overview_text = f"[bold cyan]Media API:[/bold cyan] {api.title()}\n"
overview_text += f"[bold cyan]Total Media:[/bold cyan] {stats.get('total_media', 0)}\n"
overview_text += f"[bold cyan]Registry Version:[/bold cyan] {stats.get('version', 'Unknown')}\n"
overview_text += f"[bold cyan]Last Updated:[/bold cyan] {stats.get('last_updated', 'Never')}\n"
overview_text += f"[bold cyan]Storage Size:[/bold cyan] {stats.get('storage_size', 'Unknown')}"
overview_text += (
f"[bold cyan]Total Media:[/bold cyan] {stats.get('total_media', 0)}\n"
)
overview_text += (
f"[bold cyan]Registry Version:[/bold cyan] {stats.get('version', 'Unknown')}\n"
)
overview_text += (
f"[bold cyan]Last Updated:[/bold cyan] {stats.get('last_updated', 'Never')}\n"
)
overview_text += (
f"[bold cyan]Storage Size:[/bold cyan] {stats.get('storage_size', 'Unknown')}"
)
panel = Panel(
overview_text,
title=f"{'📊 ' if icons else ''}Registry Overview",
border_style="cyan"
border_style="cyan",
)
console.print(panel)
console.print()
# Status breakdown table
status_breakdown = stats.get('status_breakdown', {})
status_breakdown = stats.get("status_breakdown", {})
if status_breakdown:
table = Table(title=f"{'📋 ' if icons else ''}Status Breakdown")
table.add_column("Status", style="cyan", no_wrap=True)
table.add_column("Count", style="magenta", justify="right")
table.add_column("Percentage", style="green", justify="right")
total = sum(status_breakdown.values())
for status, count in sorted(status_breakdown.items()):
percentage = (count / total * 100) if total > 0 else 0
table.add_row(
status.title(),
str(count),
f"{percentage:.1f}%"
)
table.add_row(status.title(), str(count), f"{percentage:.1f}%")
console.print(table)
console.print()
# Download status breakdown
download_stats = stats.get('download_stats', {})
download_stats = stats.get("download_stats", {})
if download_stats:
table = Table(title=f"{'💾 ' if icons else ''}Download Status")
table.add_column("Status", style="cyan", no_wrap=True)
table.add_column("Count", style="magenta", justify="right")
for status, count in download_stats.items():
table.add_row(status.title(), str(count))
console.print(table)
console.print()
def _display_detailed_stats(console: Console, stats: dict, icons: bool):
"""Display detailed breakdown by various categories."""
# Genre breakdown
genre_breakdown = stats.get('genre_breakdown', {})
genre_breakdown = stats.get("genre_breakdown", {})
if genre_breakdown:
table = Table(title=f"{'🎭 ' if icons else ''}Top Genres")
table.add_column("Genre", style="cyan")
table.add_column("Count", style="magenta", justify="right")
# Sort by count and show top 10
top_genres = sorted(genre_breakdown.items(), key=lambda x: x[1], reverse=True)[:10]
top_genres = sorted(genre_breakdown.items(), key=lambda x: x[1], reverse=True)[
:10
]
for genre, count in top_genres:
table.add_row(genre, str(count))
console.print(table)
console.print()
# Format breakdown
format_breakdown = stats.get('format_breakdown', {})
format_breakdown = stats.get("format_breakdown", {})
if format_breakdown:
table = Table(title=f"{'📺 ' if icons else ''}Format Breakdown")
table.add_column("Format", style="cyan")
table.add_column("Count", style="magenta", justify="right")
table.add_column("Percentage", style="green", justify="right")
total = sum(format_breakdown.values())
for format_type, count in sorted(format_breakdown.items()):
percentage = (count / total * 100) if total > 0 else 0
table.add_row(
format_type,
str(count),
f"{percentage:.1f}%"
)
table.add_row(format_type, str(count), f"{percentage:.1f}%")
console.print(table)
console.print()
# Year breakdown
year_breakdown = stats.get('year_breakdown', {})
year_breakdown = stats.get("year_breakdown", {})
if year_breakdown:
table = Table(title=f"{'📅 ' if icons else ''}Release Years (Top 10)")
table.add_column("Year", style="cyan", justify="center")
table.add_column("Count", style="magenta", justify="right")
# Sort by year descending and show top 10
top_years = sorted(year_breakdown.items(), key=lambda x: x[0], reverse=True)[:10]
top_years = sorted(year_breakdown.items(), key=lambda x: x[0], reverse=True)[
:10
]
for year, count in top_years:
table.add_row(str(year), str(count))
console.print(table)
console.print()
# Rating breakdown
rating_breakdown = stats.get('rating_breakdown', {})
rating_breakdown = stats.get("rating_breakdown", {})
if rating_breakdown:
table = Table(title=f"{'' if icons else ''}Score Distribution")
table.add_column("Score Range", style="cyan")
table.add_column("Count", style="magenta", justify="right")
for score_range, count in sorted(rating_breakdown.items()):
table.add_row(score_range, str(count))
console.print(table)
console.print()

View File

@@ -89,22 +89,23 @@ def search(config: AppConfig, **options: "Unpack[Options]"):
if not anime:
raise FastAnimeError(f"Failed to fetch anime {anime_result.title}")
available_episodes: list[str] = sorted(
getattr(anime.episodes, config.stream.translation_type), key=float
)
if options["episode_range"]:
from ..utils.parser import parse_episode_range
try:
episodes_range = parse_episode_range(
options["episode_range"],
available_episodes
options["episode_range"], available_episodes
)
for episode in episodes_range:
stream_anime(config, provider, selector, anime, episode, anime_title)
stream_anime(
config, provider, selector, anime, episode, anime_title
)
except (ValueError, IndexError) as e:
raise FastAnimeError(f"Invalid episode range: {e}") from e
else:

View File

@@ -53,14 +53,20 @@ if TYPE_CHECKING:
)
@click.pass_context
@click.pass_obj
def update(config: "AppConfig", ctx: click.Context, force: bool, check_only: bool, release_notes: bool) -> None:
def update(
config: "AppConfig",
ctx: click.Context,
force: bool,
check_only: bool,
release_notes: bool,
) -> None:
"""
Update FastAnime to the latest version.
This command checks for available updates and optionally updates
the application to the latest version from the configured sources
(pip, uv, pipx, git, or nix depending on installation method).
Args:
config: The application configuration object
ctx: The click context containing CLI options
@@ -72,73 +78,83 @@ def update(config: "AppConfig", ctx: click.Context, force: bool, check_only: boo
if release_notes:
print("[cyan]Fetching latest release notes...[/]")
is_latest, release_json = check_for_updates()
if not release_json:
print("[yellow]Could not fetch release information. Please check your internet connection.[/]")
print(
"[yellow]Could not fetch release information. Please check your internet connection.[/]"
)
sys.exit(1)
version = release_json.get('tag_name', 'unknown')
release_name = release_json.get('name', version)
release_body = release_json.get('body', 'No release notes available.')
published_at = release_json.get('published_at', 'unknown')
version = release_json.get("tag_name", "unknown")
release_name = release_json.get("name", version)
release_body = release_json.get("body", "No release notes available.")
published_at = release_json.get("published_at", "unknown")
console = Console()
print(f"[bold cyan]Release: {release_name}[/]")
print(f"[dim]Version: {version}[/]")
print(f"[dim]Published: {published_at}[/]")
print()
# Display release notes as markdown if available
if release_body.strip():
markdown = Markdown(release_body)
console.print(markdown)
else:
print("[dim]No release notes available for this version.[/]")
return
elif check_only:
print("[cyan]Checking for updates...[/]")
is_latest, release_json = check_for_updates()
if not release_json:
print("[yellow]Could not check for updates. Please check your internet connection.[/]")
print(
"[yellow]Could not check for updates. Please check your internet connection.[/]"
)
sys.exit(1)
if is_latest:
print("[green]FastAnime is up to date![/]")
print(f"[dim]Current version: {release_json.get('tag_name', 'unknown')}[/]")
print(
f"[dim]Current version: {release_json.get('tag_name', 'unknown')}[/]"
)
else:
latest_version = release_json.get('tag_name', 'unknown')
latest_version = release_json.get("tag_name", "unknown")
print(f"[yellow]Update available: {latest_version}[/]")
print(f"[dim]Run 'fastanime update' to update[/]")
print("[dim]Run 'fastanime update' to update[/]")
sys.exit(1)
else:
print("[cyan]Checking for updates and updating if necessary...[/]")
success, release_json = update_app(force=force)
if not release_json:
print("[red]Could not check for updates. Please check your internet connection.[/]")
print(
"[red]Could not check for updates. Please check your internet connection.[/]"
)
sys.exit(1)
if success:
latest_version = release_json.get('tag_name', 'unknown')
latest_version = release_json.get("tag_name", "unknown")
print(f"[green]Successfully updated to version {latest_version}![/]")
else:
if force:
print("[red]Update failed. Please check the error messages above.[/]")
print(
"[red]Update failed. Please check the error messages above.[/]"
)
sys.exit(1)
# If not forced and update failed, it might be because already up to date
# The update_app function already prints appropriate messages
except KeyboardInterrupt:
print("\n[yellow]Update cancelled by user.[/]")
sys.exit(1)
except Exception as e:
print(f"[red]An error occurred during update: {e}[/]")
# Get trace option from parent context
trace = ctx.parent.params.get('trace', False) if ctx.parent else False
trace = ctx.parent.params.get("trace", False) if ctx.parent else False
if trace:
raise
sys.exit(1)

View File

@@ -1,5 +1,4 @@
import configparser
import json
from pathlib import Path
from typing import Dict

View File

@@ -0,0 +1,150 @@
import json
import logging
import os
import tempfile
from pathlib import Path
from .....core.constants import APP_CACHE_DIR, SCRIPTS_DIR
from .....libs.media_api.params import MediaSearchParams
from ...session import Context, session
from ...state import InternalDirective, MediaApiState, MenuName, State
logger = logging.getLogger(__name__)
SEARCH_CACHE_DIR = APP_CACHE_DIR / "search"
SEARCH_RESULTS_FILE = SEARCH_CACHE_DIR / "current_search_results.json"
FZF_SCRIPTS_DIR = SCRIPTS_DIR / "fzf"
SEARCH_TEMPLATE_SCRIPT = (FZF_SCRIPTS_DIR / "search.template.sh").read_text(
encoding="utf-8"
)
@session.menu
def dynamic_search(ctx: Context, state: State) -> State | InternalDirective:
"""Dynamic search menu that provides real-time search results."""
feedback = ctx.service.feedback
feedback.clear_console()
# Ensure cache directory exists
SEARCH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
# Read the GraphQL search query
from .....libs.media_api.anilist import gql
search_query = gql.SEARCH_MEDIA.read_text(encoding="utf-8")
# Properly escape the GraphQL query for JSON
search_query_escaped = json.dumps(search_query)
# Prepare the search script
auth_header = ""
if ctx.media_api.is_authenticated() and hasattr(ctx.media_api, 'token'):
auth_header = f"Bearer {ctx.media_api.token}"
# Create a temporary search script
with tempfile.NamedTemporaryFile(mode='w', suffix='.sh', delete=False) as temp_script:
script_content = SEARCH_TEMPLATE_SCRIPT
replacements = {
"GRAPHQL_ENDPOINT": "https://graphql.anilist.co",
"GRAPHQL_QUERY": search_query_escaped,
"CACHE_DIR": str(SEARCH_CACHE_DIR),
"SEARCH_RESULTS_FILE": str(SEARCH_RESULTS_FILE),
"AUTH_HEADER": auth_header,
}
for key, value in replacements.items():
script_content = script_content.replace(f"{{{key}}}", str(value))
temp_script.write(script_content)
temp_script_path = temp_script.name
try:
# Make the script executable
os.chmod(temp_script_path, 0o755)
# Use the selector's search functionality
try:
# Prepare preview functionality
preview_command = None
if ctx.config.general.preview != "none":
from ....utils.preview import create_preview_context
with create_preview_context() as preview_ctx:
preview_command = preview_ctx.get_dynamic_anime_preview(ctx.config)
choice = ctx.selector.search(
prompt="Search Anime",
search_command=f"bash {temp_script_path} {{q}}",
preview=preview_command,
header="Type to search for anime dynamically"
)
else:
choice = ctx.selector.search(
prompt="Search Anime",
search_command=f"bash {temp_script_path} {{q}}",
header="Type to search for anime dynamically"
)
except NotImplementedError:
feedback.error("Dynamic search is not supported by your current selector")
feedback.info("Please use the regular search option or switch to fzf selector")
return InternalDirective.MAIN
if not choice:
return InternalDirective.MAIN
# Read the cached search results
if not SEARCH_RESULTS_FILE.exists():
logger.error("Search results file not found")
return InternalDirective.MAIN
try:
with open(SEARCH_RESULTS_FILE, 'r', encoding='utf-8') as f:
raw_data = json.load(f)
# Transform the raw data into MediaSearchResult
search_result = ctx.media_api.transform_raw_search_data(raw_data)
if not search_result or not search_result.media:
feedback.info("No results found")
return InternalDirective.MAIN
# Find the selected media item by matching the choice with the displayed format
selected_media = None
for media_item in search_result.media:
title = media_item.title.english or media_item.title.romaji or media_item.title.native or "Unknown"
year = media_item.start_date.year if media_item.start_date else "Unknown"
status = media_item.status.value if media_item.status else "Unknown"
genres = ", ".join([genre.value for genre in media_item.genres[:3]]) if media_item.genres else "Unknown"
display_format = f"{title} ({year}) [{status}] - {genres}"
if choice.strip() == display_format.strip():
selected_media = media_item
break
if not selected_media:
logger.error(f"Could not find selected media for choice: {choice}")
return InternalDirective.MAIN
# Navigate to media actions with the selected item
return State(
menu_name=MenuName.MEDIA_ACTIONS,
media_api=MediaApiState(
search_result={selected_media.id: selected_media},
media_id=selected_media.id,
search_params=MediaSearchParams(),
page_info=search_result.page_info,
),
)
except (json.JSONDecodeError, KeyError, Exception) as e:
logger.error(f"Error processing search results: {e}")
feedback.error("Failed to process search results")
return InternalDirective.MAIN
finally:
# Clean up the temporary script
try:
os.unlink(temp_script_path)
except OSError:
pass

View File

@@ -39,21 +39,34 @@ def episodes(ctx: Context, state: State) -> State | InternalDirective:
preview_command = None
if ctx.config.general.preview != "none":
from ....utils.preview import get_episode_preview
from ....utils.preview import create_preview_context
preview_command = get_episode_preview(
available_episodes, media_item, ctx.config
with create_preview_context() as preview_ctx:
preview_command = preview_ctx.get_episode_preview(
available_episodes, media_item, ctx.config
)
chosen_episode_str = ctx.selector.choose(
prompt="Select Episode", choices=choices, preview=preview_command
)
if not chosen_episode_str or chosen_episode_str == "Back":
# TODO: should improve the back logic for menus that can be pass through
return InternalDirective.BACKX2
chosen_episode = chosen_episode_str
# Workers are automatically cleaned up when exiting the context
else:
# No preview mode
chosen_episode_str = ctx.selector.choose(
prompt="Select Episode", choices=choices, preview=None
)
chosen_episode_str = ctx.selector.choose(
prompt="Select Episode", choices=choices, preview=preview_command
)
if not chosen_episode_str or chosen_episode_str == "Back":
# TODO: should improve the back logic for menus that can be pass through
return InternalDirective.BACKX2
if not chosen_episode_str or chosen_episode_str == "Back":
# TODO: should improve the back logic for menus that can be pass through
return InternalDirective.BACKX2
chosen_episode = chosen_episode_str
chosen_episode = chosen_episode_str
# Track episode selection in watch history (if enabled in config)
if (

View File

@@ -39,6 +39,7 @@ def main(ctx: Context, state: State) -> State | InternalDirective:
ctx, state, UserMediaListStatus.PLANNING
),
f"{'🔎 ' if icons else ''}Search": _create_search_media_list(ctx, state),
f"{'🔍 ' if icons else ''}Dynamic Search": _create_dynamic_search_action(ctx, state),
f"{'🏠 ' if icons else ''}Downloads": _create_downloads_action(ctx, state),
f"{'🔔 ' if icons else ''}Recently Updated": _create_media_list_action(
ctx, state, MediaSort.UPDATED_AT_DESC
@@ -228,3 +229,12 @@ def _create_downloads_action(ctx: Context, state: State) -> MenuAction:
return State(menu_name=MenuName.DOWNLOADS)
return action
def _create_dynamic_search_action(ctx: Context, state: State) -> MenuAction:
"""Create action to navigate to the dynamic search menu."""
def action():
return State(menu_name=MenuName.DYNAMIC_SEARCH)
return action

View File

@@ -177,7 +177,9 @@ def _view_info(ctx: Context, state: State) -> MenuAction:
image.render_image(cover_image.large)
# Create main title
main_title = media_item.title.english or media_item.title.romaji or "Unknown Title"
main_title = (
media_item.title.english or media_item.title.romaji or "Unknown Title"
)
title_text = Text(main_title, style="bold cyan")
# Create info table
@@ -189,7 +191,7 @@ def _view_info(ctx: Context, state: State) -> MenuAction:
info_table.add_row("English Title", media_item.title.english or "N/A")
info_table.add_row("Romaji Title", media_item.title.romaji or "N/A")
info_table.add_row("Native Title", media_item.title.native or "N/A")
if media_item.synonymns:
synonyms = ", ".join(media_item.synonymns[:3]) # Show first 3 synonyms
if len(media_item.synonymns) > 3:
@@ -197,10 +199,19 @@ def _view_info(ctx: Context, state: State) -> MenuAction:
info_table.add_row("Synonyms", synonyms)
info_table.add_row("Type", media_item.type.value if media_item.type else "N/A")
info_table.add_row("Format", media_item.format.value if media_item.format else "N/A")
info_table.add_row("Status", media_item.status.value if media_item.status else "N/A")
info_table.add_row("Episodes", str(media_item.episodes) if media_item.episodes else "Unknown")
info_table.add_row("Duration", f"{media_item.duration} min" if media_item.duration else "Unknown")
info_table.add_row(
"Format", media_item.format.value if media_item.format else "N/A"
)
info_table.add_row(
"Status", media_item.status.value if media_item.status else "N/A"
)
info_table.add_row(
"Episodes", str(media_item.episodes) if media_item.episodes else "Unknown"
)
info_table.add_row(
"Duration",
f"{media_item.duration} min" if media_item.duration else "Unknown",
)
# Add dates
if media_item.start_date:
@@ -229,63 +240,72 @@ def _view_info(ctx: Context, state: State) -> MenuAction:
Text(genres_text, style="green"),
title="[bold]Genres[/bold]",
border_style="green",
box=box.ROUNDED
box=box.ROUNDED,
)
else:
genres_panel = Panel(
Text("No genres available", style="dim"),
title="[bold]Genres[/bold]",
border_style="green",
box=box.ROUNDED
box=box.ROUNDED,
)
# Create tags panel (show top tags)
if media_item.tags:
top_tags = sorted(media_item.tags, key=lambda x: x.rank or 0, reverse=True)[:10]
top_tags = sorted(media_item.tags, key=lambda x: x.rank or 0, reverse=True)[
:10
]
tags_text = ", ".join([tag.name.value for tag in top_tags])
tags_panel = Panel(
Text(tags_text, style="yellow"),
title="[bold]Tags[/bold]",
border_style="yellow",
box=box.ROUNDED
box=box.ROUNDED,
)
else:
tags_panel = Panel(
Text("No tags available", style="dim"),
title="[bold]Tags[/bold]",
border_style="yellow",
box=box.ROUNDED
box=box.ROUNDED,
)
# Create studios panel
if media_item.studios:
studios_text = ", ".join([studio.name for studio in media_item.studios if studio.name])
studios_text = ", ".join(
[studio.name for studio in media_item.studios if studio.name]
)
studios_panel = Panel(
Text(studios_text, style="blue"),
title="[bold]Studios[/bold]",
border_style="blue",
box=box.ROUNDED
box=box.ROUNDED,
)
else:
studios_panel = Panel(
Text("No studio information", style="dim"),
title="[bold]Studios[/bold]",
border_style="blue",
box=box.ROUNDED
box=box.ROUNDED,
)
# Create description panel
description = media_item.description or "No description available"
# Clean HTML tags from description
clean_description = re.sub(r'<[^>]+>', '', description)
clean_description = re.sub(r"<[^>]+>", "", description)
# Replace common HTML entities
clean_description = clean_description.replace('&quot;', '"').replace('&amp;', '&').replace('&lt;', '<').replace('&gt;', '>')
clean_description = (
clean_description.replace("&quot;", '"')
.replace("&amp;", "&")
.replace("&lt;", "<")
.replace("&gt;", ">")
)
description_panel = Panel(
Text(clean_description, style="white"),
title="[bold]Description[/bold]",
border_style="cyan",
box=box.ROUNDED
box=box.ROUNDED,
)
# Create user status panel if available
@@ -293,35 +313,44 @@ def _view_info(ctx: Context, state: State) -> MenuAction:
user_info_table = Table(show_header=False, box=box.SIMPLE)
user_info_table.add_column("Field", style="bold magenta")
user_info_table.add_column("Value", style="white")
if media_item.user_status.status:
user_info_table.add_row("Status", media_item.user_status.status.value.title())
user_info_table.add_row(
"Status", media_item.user_status.status.value.title()
)
if media_item.user_status.progress is not None:
progress = f"{media_item.user_status.progress}/{media_item.episodes or '?'}"
progress = (
f"{media_item.user_status.progress}/{media_item.episodes or '?'}"
)
user_info_table.add_row("Progress", progress)
if media_item.user_status.score:
user_info_table.add_row("Your Score", f"{media_item.user_status.score}/10")
user_info_table.add_row(
"Your Score", f"{media_item.user_status.score}/10"
)
if media_item.user_status.repeat:
user_info_table.add_row("Rewatched", f"{media_item.user_status.repeat} times")
user_info_table.add_row(
"Rewatched", f"{media_item.user_status.repeat} times"
)
user_panel = Panel(
user_info_table,
title="[bold]Your List Status[/bold]",
border_style="magenta",
box=box.ROUNDED
box=box.ROUNDED,
)
else:
user_panel = None
# Create next airing panel if available
if media_item.next_airing:
from datetime import datetime
airing_info_table = Table(show_header=False, box=box.SIMPLE)
airing_info_table.add_column("Field", style="bold red")
airing_info_table.add_column("Value", style="white")
airing_info_table.add_row("Next Episode", str(media_item.next_airing.episode))
airing_info_table.add_row(
"Next Episode", str(media_item.next_airing.episode)
)
if media_item.next_airing.airing_at:
air_date = media_item.next_airing.airing_at.strftime("%Y-%m-%d %H:%M")
airing_info_table.add_row("Air Date", air_date)
@@ -330,7 +359,7 @@ def _view_info(ctx: Context, state: State) -> MenuAction:
airing_info_table,
title="[bold]Next Airing[/bold]",
border_style="red",
box=box.ROUNDED
box=box.ROUNDED,
)
else:
airing_panel = None
@@ -340,30 +369,30 @@ def _view_info(ctx: Context, state: State) -> MenuAction:
info_table,
title="[bold]Basic Information[/bold]",
border_style="cyan",
box=box.ROUNDED
box=box.ROUNDED,
)
# Display everything
console.print(Panel(title_text, box=box.DOUBLE, border_style="bright_cyan"))
console.print()
# Create columns for better layout
panels_row1 = [info_panel, genres_panel]
if user_panel:
panels_row1.append(user_panel)
console.print(Columns(panels_row1, equal=True, expand=True))
console.print()
panels_row2 = [tags_panel, studios_panel]
if airing_panel:
panels_row2.append(airing_panel)
console.print(Columns(panels_row2, equal=True, expand=True))
console.print()
console.print(description_panel)
ctx.selector.ask("Press Enter to continue...")
return InternalDirective.RELOAD
@@ -479,7 +508,7 @@ def _view_characters(ctx: Context, state: State) -> MenuAction:
loading_message = "Fetching characters..."
characters_data = None
with feedback.progress(loading_message):
characters_data = ctx.media_api.get_characters_of(
MediaCharactersParams(id=media_item.id)
@@ -487,8 +516,8 @@ def _view_characters(ctx: Context, state: State) -> MenuAction:
if not characters_data or not characters_data.get("data"):
feedback.warning(
"No character information found",
"This anime doesn't have character data available"
"No character information found",
"This anime doesn't have character data available",
)
return InternalDirective.RELOAD
@@ -496,7 +525,7 @@ def _view_characters(ctx: Context, state: State) -> MenuAction:
# Extract characters from the nested response structure
page_data = characters_data["data"]["Page"]["media"][0]
characters = page_data["characters"]["nodes"]
if not characters:
feedback.warning("No characters found for this anime")
return InternalDirective.RELOAD
@@ -506,7 +535,6 @@ def _view_characters(ctx: Context, state: State) -> MenuAction:
from rich.table import Table
from rich.panel import Panel
from rich.text import Text
from datetime import datetime
console = Console()
console.clear()
@@ -528,12 +556,15 @@ def _view_characters(ctx: Context, state: State) -> MenuAction:
gender = char.get("gender") or "Unknown"
age = str(char.get("age") or "Unknown")
favorites = str(char.get("favourites") or "0")
# Clean up description (remove HTML tags and truncate)
description = char.get("description") or "No description"
if description:
import re
description = re.sub(r'<[^>]+>', '', description) # Remove HTML tags
description = re.sub(
r"<[^>]+>", "", description
) # Remove HTML tags
if len(description) > 100:
description = description[:97] + "..."
@@ -542,12 +573,12 @@ def _view_characters(ctx: Context, state: State) -> MenuAction:
# Display in a panel
panel = Panel(table, title=title, border_style="blue")
console.print(panel)
ctx.selector.ask("Press Enter to continue...")
except (KeyError, IndexError, TypeError) as e:
feedback.error(f"Error displaying characters: {e}")
return InternalDirective.RELOAD
return action
@@ -564,7 +595,7 @@ def _view_airing_schedule(ctx: Context, state: State) -> MenuAction:
loading_message = "Fetching airing schedule..."
schedule_data = None
with feedback.progress(loading_message):
schedule_data = ctx.media_api.get_airing_schedule_for(
MediaAiringScheduleParams(id=media_item.id)
@@ -572,8 +603,8 @@ def _view_airing_schedule(ctx: Context, state: State) -> MenuAction:
if not schedule_data or not schedule_data.get("data"):
feedback.warning(
"No airing schedule found",
"This anime doesn't have upcoming episodes or airing data"
"No airing schedule found",
"This anime doesn't have upcoming episodes or airing data",
)
return InternalDirective.RELOAD
@@ -581,11 +612,11 @@ def _view_airing_schedule(ctx: Context, state: State) -> MenuAction:
# Extract schedule from the nested response structure
page_data = schedule_data["data"]["Page"]["media"][0]
schedule_nodes = page_data["airingSchedule"]["nodes"]
if not schedule_nodes:
feedback.info(
"No upcoming episodes",
"This anime has no scheduled upcoming episodes"
"This anime has no scheduled upcoming episodes",
)
return InternalDirective.RELOAD
@@ -611,7 +642,7 @@ def _view_airing_schedule(ctx: Context, state: State) -> MenuAction:
for episode in schedule_nodes[:10]: # Show next 10 episodes
ep_num = str(episode.get("episode", "?"))
# Format air date
airing_at = episode.get("airingAt")
if airing_at:
@@ -619,14 +650,14 @@ def _view_airing_schedule(ctx: Context, state: State) -> MenuAction:
formatted_date = air_date.strftime("%Y-%m-%d %H:%M")
else:
formatted_date = "Unknown"
# Format time until airing
time_until = episode.get("timeUntilAiring")
if time_until:
days = time_until // 86400
hours = (time_until % 86400) // 3600
minutes = (time_until % 3600) // 60
if days > 0:
time_str = f"{days}d {hours}h {minutes}m"
elif hours > 0:
@@ -641,12 +672,12 @@ def _view_airing_schedule(ctx: Context, state: State) -> MenuAction:
# Display in a panel
panel = Panel(table, title=title, border_style="blue")
console.print(panel)
ctx.selector.ask("Press Enter to continue...")
except (KeyError, IndexError, TypeError) as e:
feedback.error(f"Error displaying airing schedule: {e}")
return InternalDirective.RELOAD
return action

View File

@@ -30,9 +30,12 @@ def provider_search(ctx: Context, state: State) -> State | InternalDirective:
return InternalDirective.BACK
provider_search_results = provider.search(
SearchParams(query=normalize_title(media_title, config.general.provider.value,True), translation_type=config.stream.translation_type)
SearchParams(
query=normalize_title(media_title, config.general.provider.value, True),
translation_type=config.stream.translation_type,
)
)
if not provider_search_results or not provider_search_results.results:
feedback.warning(
f"Could not find '{media_title}' on {provider.__class__.__name__}",
@@ -51,7 +54,10 @@ def provider_search(ctx: Context, state: State) -> State | InternalDirective:
# Use fuzzy matching to find the best title
best_match_title = max(
provider_results_map.keys(),
key=lambda p_title: fuzz.ratio(normalize_title(p_title,config.general.provider.value).lower(), media_title.lower()),
key=lambda p_title: fuzz.ratio(
normalize_title(p_title, config.general.provider.value).lower(),
media_title.lower(),
),
)
feedback.info("Auto-selecting best match: {best_match_title}")
selected_provider_anime = provider_results_map[best_match_title]

View File

@@ -23,17 +23,6 @@ def results(ctx: Context, state: State) -> State | InternalDirective:
_format_title(ctx, media_item): media_item
for media_item in search_result.values()
}
preview_command = None
if ctx.config.general.preview != "none":
from ....utils.preview import get_anime_preview
preview_command = get_anime_preview(
list(search_result_dict.values()),
list(search_result_dict.keys()),
ctx.config,
)
choices: Dict[str, Callable[[], Union[int, State, InternalDirective]]] = {
title: lambda media_id=item.id: media_id
for title, item in search_result_dict.items()
@@ -64,11 +53,31 @@ def results(ctx: Context, state: State) -> State | InternalDirective:
}
)
choice = ctx.selector.choose(
prompt="Select Anime",
choices=list(choices),
preview=preview_command,
)
preview_command = None
if ctx.config.general.preview != "none":
from ....utils.preview import create_preview_context
with create_preview_context() as preview_ctx:
preview_command = preview_ctx.get_anime_preview(
list(search_result_dict.values()),
list(search_result_dict.keys()),
ctx.config,
)
choice = ctx.selector.choose(
prompt="Select Anime",
choices=list(choices),
preview=preview_command,
)
else:
# No preview mode
choice = ctx.selector.choose(
prompt="Select Anime",
choices=list(choices),
preview=None,
)
if not choice:
return InternalDirective.RELOAD

View File

@@ -73,7 +73,11 @@ def servers(ctx: Context, state: State) -> State | InternalDirective:
)
return InternalDirective.RELOAD
final_title = f"{provider_anime.title} - Ep {episode_number}"
final_title = (
media_item.streaming_episodes[episode_number].title
if media_item.streaming_episodes.get(episode_number)
else f"{media_item.title.english} - Ep {episode_number}"
)
feedback.info(f"[bold green]Launching player for:[/] {final_title}")
player_result = ctx.player.play(

View File

@@ -134,8 +134,21 @@ class Session:
except Exception:
self._context.service.session.create_crash_backup(self._history)
raise
finally:
# Clean up preview workers when session ends
self._cleanup_preview_workers()
self._context.service.session.save_session(self._history)
def _cleanup_preview_workers(self):
"""Clean up preview workers when session ends."""
try:
from ..utils.preview import shutdown_preview_workers
shutdown_preview_workers(wait=False, timeout=5.0)
logger.debug("Preview workers cleaned up successfully")
except Exception as e:
logger.warning(f"Failed to cleanup preview workers: {e}")
def _run_main_loop(self):
"""Run the main session loop."""
while self._history:

View File

@@ -40,6 +40,7 @@ class MenuName(Enum):
SESSION_MANAGEMENT = "SESSION_MANAGEMENT"
MEDIA_ACTIONS = "MEDIA_ACTIONS"
DOWNLOADS = "DOWNLOADS"
DYNAMIC_SEARCH = "DYNAMIC_SEARCH"
class StateModel(BaseModel):

View File

@@ -4,7 +4,7 @@ import logging
from pathlib import Path
from typing import Optional
from ....core.config.model import AppConfig, DownloadsConfig
from ....core.config.model import AppConfig
from ....core.downloader.base import BaseDownloader
from ....core.downloader.downloader import create_downloader
from ....core.downloader.params import DownloadParams
@@ -51,30 +51,34 @@ class DownloadService:
) -> bool:
"""
Download a specific episode and record it in the registry.
Args:
media_item: The media item to download
episode_number: The episode number to download
server: Optional specific server to use for download
quality: Optional quality preference
force_redownload: Whether to redownload if already exists
Returns:
bool: True if download was successful, False otherwise
"""
try:
# Get or create media record
media_record = self.media_registry.get_or_create_record(media_item)
# Check if episode already exists and is completed
existing_episode = self._find_episode_in_record(media_record, episode_number)
existing_episode = self._find_episode_in_record(
media_record, episode_number
)
if (
existing_episode
existing_episode
and existing_episode.download_status == DownloadStatus.COMPLETED
and not force_redownload
and existing_episode.file_path.exists()
):
logger.info(f"Episode {episode_number} already downloaded at {existing_episode.file_path}")
logger.info(
f"Episode {episode_number} already downloaded at {existing_episode.file_path}"
)
return True
# Generate file path
@@ -130,8 +134,10 @@ class DownloadService:
file_size=file_size,
subtitle_paths=download_result.subtitle_paths,
)
logger.info(f"Successfully downloaded episode {episode_number} to {download_result.video_path}")
logger.info(
f"Successfully downloaded episode {episode_number} to {download_result.video_path}"
)
else:
# Update episode record with failure
self.media_registry.update_episode_download_status(
@@ -140,8 +146,10 @@ class DownloadService:
status=DownloadStatus.FAILED,
error_message=download_result.error_message,
)
logger.error(f"Failed to download episode {episode_number}: {download_result.error_message}")
logger.error(
f"Failed to download episode {episode_number}: {download_result.error_message}"
)
return download_result.success
@@ -157,7 +165,7 @@ class DownloadService:
)
except Exception as cleanup_error:
logger.error(f"Failed to update failed status: {cleanup_error}")
return False
def download_multiple_episodes(
@@ -169,18 +177,18 @@ class DownloadService:
) -> dict[str, bool]:
"""
Download multiple episodes and return success status for each.
Args:
media_item: The media item to download
episode_numbers: List of episode numbers to download
quality: Optional quality preference
force_redownload: Whether to redownload if already exists
Returns:
dict: Mapping of episode_number -> success status
"""
results = {}
for episode_number in episode_numbers:
success = self.download_episode(
media_item=media_item,
@@ -189,18 +197,22 @@ class DownloadService:
force_redownload=force_redownload,
)
results[episode_number] = success
# Log progress
logger.info(f"Download progress: {episode_number} - {'' if success else ''}")
logger.info(
f"Download progress: {episode_number} - {'' if success else ''}"
)
return results
def get_download_status(self, media_item: MediaItem, episode_number: str) -> Optional[DownloadStatus]:
def get_download_status(
self, media_item: MediaItem, episode_number: str
) -> Optional[DownloadStatus]:
"""Get the download status for a specific episode."""
media_record = self.media_registry.get_media_record(media_item.id)
if not media_record:
return None
episode_record = self._find_episode_in_record(media_record, episode_number)
return episode_record.download_status if episode_record else None
@@ -209,7 +221,7 @@ class DownloadService:
media_record = self.media_registry.get_media_record(media_item.id)
if not media_record:
return []
return [
episode.episode_number
for episode in media_record.media_episodes
@@ -217,38 +229,43 @@ class DownloadService:
and episode.file_path.exists()
]
def remove_downloaded_episode(self, media_item: MediaItem, episode_number: str) -> bool:
def remove_downloaded_episode(
self, media_item: MediaItem, episode_number: str
) -> bool:
"""Remove a downloaded episode file and update registry."""
try:
media_record = self.media_registry.get_media_record(media_item.id)
if not media_record:
return False
episode_record = self._find_episode_in_record(media_record, episode_number)
if not episode_record:
return False
# Remove file if it exists
if episode_record.file_path.exists():
episode_record.file_path.unlink()
# Remove episode from record
media_record.media_episodes = [
ep for ep in media_record.media_episodes
ep
for ep in media_record.media_episodes
if ep.episode_number != episode_number
]
# Save updated record
self.media_registry.save_media_record(media_record)
logger.info(f"Removed downloaded episode {episode_number}")
return True
except Exception as e:
logger.error(f"Error removing episode {episode_number}: {e}")
return False
def _find_episode_in_record(self, media_record, episode_number: str) -> Optional[MediaEpisode]:
def _find_episode_in_record(
self, media_record, episode_number: str
) -> Optional[MediaEpisode]:
"""Find an episode record by episode number."""
for episode in media_record.media_episodes:
if episode.episode_number == episode_number:
@@ -303,7 +320,7 @@ class DownloadService:
"""Download episode from a specific server."""
anime_title = media_item.title.english or media_item.title.romaji or "Unknown"
episode_title = server.episode_title or f"Episode {episode_number}"
try:
# Get the best quality link from server
if not server.links:
@@ -319,7 +336,9 @@ class DownloadService:
episode_title=episode_title,
silent=True, # Use True by default since there's no verbose in config
headers=server.headers,
subtitles=[sub.url for sub in server.subtitles] if server.subtitles else [],
subtitles=[sub.url for sub in server.subtitles]
if server.subtitles
else [],
vid_format=self.downloads_config.preferred_quality,
force_unknown_ext=True,
)
@@ -333,6 +352,7 @@ class DownloadService:
except Exception as e:
logger.error(f"Error during download: {e}")
from ....core.downloader.model import DownloadResult
return DownloadResult(
success=False,
error_message=str(e),
@@ -346,28 +366,34 @@ class DownloadService:
def get_failed_downloads(self) -> list[tuple[int, str]]:
"""Get all episodes that failed to download."""
return self.media_registry.get_episodes_by_download_status(DownloadStatus.FAILED)
return self.media_registry.get_episodes_by_download_status(
DownloadStatus.FAILED
)
def get_queued_downloads(self) -> list[tuple[int, str]]:
"""Get all episodes queued for download."""
return self.media_registry.get_episodes_by_download_status(DownloadStatus.QUEUED)
return self.media_registry.get_episodes_by_download_status(
DownloadStatus.QUEUED
)
def retry_failed_downloads(self, max_retries: int = 3) -> dict[str, bool]:
"""Retry all failed downloads up to max_retries."""
failed_episodes = self.get_failed_downloads()
results = {}
for media_id, episode_number in failed_episodes:
# Get the media record to check retry attempts
media_record = self.media_registry.get_media_record(media_id)
if not media_record:
continue
episode_record = self._find_episode_in_record(media_record, episode_number)
if not episode_record or episode_record.download_attempts >= max_retries:
logger.info(f"Skipping {media_id}:{episode_number} - max retries exceeded")
logger.info(
f"Skipping {media_id}:{episode_number} - max retries exceeded"
)
continue
logger.info(f"Retrying download for {media_id}:{episode_number}")
success = self.download_episode(
media_item=media_record.media_item,
@@ -375,40 +401,41 @@ class DownloadService:
force_redownload=True,
)
results[f"{media_id}:{episode_number}"] = success
return results
def cleanup_failed_downloads(self, older_than_days: int = 7) -> int:
"""Clean up failed download records older than specified days."""
from datetime import datetime, timedelta
cleanup_count = 0
cutoff_date = datetime.now() - timedelta(days=older_than_days)
try:
for record in self.media_registry.get_all_media_records():
episodes_to_remove = []
for episode in record.media_episodes:
if (
episode.download_status == DownloadStatus.FAILED
and episode.download_date < cutoff_date
):
episodes_to_remove.append(episode.episode_number)
for episode_number in episodes_to_remove:
record.media_episodes = [
ep for ep in record.media_episodes
ep
for ep in record.media_episodes
if ep.episode_number != episode_number
]
cleanup_count += 1
if episodes_to_remove:
self.media_registry.save_media_record(record)
logger.info(f"Cleaned up {cleanup_count} failed download records")
return cleanup_count
except Exception as e:
logger.error(f"Error during cleanup: {e}")
return 0
@@ -438,10 +465,23 @@ class DownloadService:
try:
media_record = self.media_registry.get_media_record(media_item.id)
if not media_record:
return {"total": 0, "downloaded": 0, "failed": 0, "queued": 0, "downloading": 0}
stats = {"total": 0, "downloaded": 0, "failed": 0, "queued": 0, "downloading": 0, "paused": 0}
return {
"total": 0,
"downloaded": 0,
"failed": 0,
"queued": 0,
"downloading": 0,
}
stats = {
"total": 0,
"downloaded": 0,
"failed": 0,
"queued": 0,
"downloading": 0,
"paused": 0,
}
for episode in media_record.media_episodes:
stats["total"] += 1
status = episode.download_status.value.lower()
@@ -455,26 +495,36 @@ class DownloadService:
stats["downloading"] += 1
elif status == "paused":
stats["paused"] += 1
return stats
except Exception as e:
logger.error(f"Error getting download progress: {e}")
return {"total": 0, "downloaded": 0, "failed": 0, "queued": 0, "downloading": 0}
return {
"total": 0,
"downloaded": 0,
"failed": 0,
"queued": 0,
"downloading": 0,
}
def _generate_episode_file_path(self, media_item: MediaItem, episode_number: str) -> Path:
def _generate_episode_file_path(
self, media_item: MediaItem, episode_number: str
) -> Path:
"""Generate the file path for a downloaded episode."""
# Use the download directory from config
base_dir = self.downloads_config.downloads_dir
# Create anime-specific directory
anime_title = media_item.title.english or media_item.title.romaji or "Unknown"
# Sanitize title for filesystem
safe_title = "".join(c for c in anime_title if c.isalnum() or c in (' ', '-', '_')).rstrip()
safe_title = "".join(
c for c in anime_title if c.isalnum() or c in (" ", "-", "_")
).rstrip()
anime_dir = base_dir / safe_title
# Generate filename (could use template from config in the future)
filename = f"Episode_{episode_number:0>2}.mp4"
return anime_dir / filename

View File

@@ -2,7 +2,7 @@ import json
import logging
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, TypedDict
from typing import Dict, Generator, List, Optional, TypedDict
from ....core.config.model import MediaRegistryConfig
from ....core.exceptions import FastAnimeError
@@ -12,7 +12,6 @@ from ....libs.media_api.types import (
MediaItem,
MediaSearchResult,
PageInfo,
UserListItem,
UserMediaListStatus,
)
from .models import (
@@ -586,8 +585,6 @@ class MediaRegistryService:
) -> list[tuple[int, str]]:
"""Get all episodes with a specific download status."""
try:
from .models import DownloadStatus
episodes = []
for record in self.get_all_media_records():
for episode in record.media_episodes:
@@ -602,8 +599,6 @@ class MediaRegistryService:
def get_download_statistics(self) -> dict:
"""Get comprehensive download statistics."""
try:
from .models import DownloadStatus
stats = {
"total_episodes": 0,
"downloaded": 0,

View File

@@ -1,5 +0,0 @@
"""CLI utilities for FastAnime."""
from .parser import parse_episode_range
__all__ = ["parse_episode_range"]

View File

@@ -3,6 +3,7 @@ import importlib
import click
# TODO: since command structure is pretty obvious default to only requiring mapping of command names to their function name(cause some have special names like import)
class LazyGroup(click.Group):
def __init__(self, root: str, *args, lazy_subcommands=None, **kwargs):
super().__init__(*args, **kwargs)

View File

@@ -4,12 +4,11 @@ from typing import Iterator
def parse_episode_range(
episode_range_str: str | None,
available_episodes: list[str]
episode_range_str: str | None, available_episodes: list[str]
) -> Iterator[str]:
"""
Parse an episode range string and return an iterator of episode numbers.
This function handles various episode range formats:
- Single episode: "5" -> episodes from index 5 onwards
- Range with start and end: "5:10" -> episodes from index 5 to 10 (exclusive)
@@ -17,18 +16,18 @@ def parse_episode_range(
- Start only: "5:" -> episodes from index 5 onwards
- End only: ":10" -> episodes from beginning to index 10
- All episodes: ":" -> all episodes
Args:
episode_range_str: The episode range string to parse (e.g., "5:10", "5:", ":10", "5")
available_episodes: List of available episode numbers/identifiers
Returns:
Iterator over the selected episode numbers
Raises:
ValueError: If the episode range format is invalid
IndexError: If the specified indices are out of range
Examples:
>>> episodes = ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"]
>>> list(parse_episode_range("2:5", episodes))
@@ -43,14 +42,14 @@ def parse_episode_range(
if not episode_range_str:
# No range specified, return all episodes
return iter(available_episodes)
# Sort episodes numerically for consistent ordering
episodes = sorted(available_episodes, key=float)
if ":" in episode_range_str:
# Handle colon-separated ranges
parts = episode_range_str.split(":")
if len(parts) == 3:
# Format: start:end:step
start_str, end_str, step_str = parts
@@ -59,15 +58,15 @@ def parse_episode_range(
f"Invalid episode range format: '{episode_range_str}'. "
"When using 3 parts (start:end:step), all parts must be non-empty."
)
try:
start_idx = int(start_str)
end_idx = int(end_str)
step = int(step_str)
if step <= 0:
raise ValueError("Step value must be positive")
return iter(episodes[start_idx:end_idx:step])
except ValueError as e:
if "invalid literal" in str(e):
@@ -76,11 +75,11 @@ def parse_episode_range(
"All parts must be valid integers."
) from e
raise
elif len(parts) == 2:
# Format: start:end or start: or :end
start_str, end_str = parts
if start_str and end_str:
# Both start and end specified: start:end
try:
@@ -92,7 +91,7 @@ def parse_episode_range(
f"Invalid episode range format: '{episode_range_str}'. "
"Start and end must be valid integers."
) from e
elif start_str and not end_str:
# Only start specified: start:
try:
@@ -103,7 +102,7 @@ def parse_episode_range(
f"Invalid episode range format: '{episode_range_str}'. "
"Start must be a valid integer."
) from e
elif not start_str and end_str:
# Only end specified: :end
try:
@@ -114,7 +113,7 @@ def parse_episode_range(
f"Invalid episode range format: '{episode_range_str}'. "
"End must be a valid integer."
) from e
else:
# Both empty: ":"
return iter(episodes)

View File

@@ -1,21 +1,13 @@
import concurrent.futures
import logging
import os
import re
from hashlib import sha256
from pathlib import Path
from threading import Thread
from typing import List
import httpx
from ...core.utils import formatter
from typing import List, Optional
from ...core.config import AppConfig
from ...core.constants import APP_CACHE_DIR, PLATFORM, SCRIPTS_DIR
from ...core.utils.file import AtomicWriter
from ...libs.media_api.types import MediaItem
from . import ansi
from .preview_workers import PreviewWorkerManager
logger = logging.getLogger(__name__)
@@ -26,23 +18,109 @@ IMAGES_CACHE_DIR = PREVIEWS_CACHE_DIR / "images"
INFO_CACHE_DIR = PREVIEWS_CACHE_DIR / "info"
FZF_SCRIPTS_DIR = SCRIPTS_DIR / "fzf"
TEMPLATE_PREVIEW_SCRIPT = Path(str(FZF_SCRIPTS_DIR / "preview.template.sh")).read_text(
TEMPLATE_PREVIEW_SCRIPT = (FZF_SCRIPTS_DIR / "preview.template.sh").read_text(
encoding="utf-8"
)
TEMPLATE_INFO_SCRIPT = Path(str(FZF_SCRIPTS_DIR / "info.template.sh")).read_text(
DYNAMIC_PREVIEW_SCRIPT = (FZF_SCRIPTS_DIR / "dynamic_preview.template.sh").read_text(
encoding="utf-8"
)
TEMPLATE_EPISODE_INFO_SCRIPT = Path(
str(FZF_SCRIPTS_DIR / "episode-info.template.sh")
).read_text(encoding="utf-8")
EPISODE_PATTERN = re.compile(r"^Episode\s+(\d+)\s-\s.*")
# Global preview worker manager instance
_preview_manager: Optional[PreviewWorkerManager] = None
def create_preview_context():
"""
Create a context manager for preview operations.
This can be used in menu functions to ensure proper cleanup:
```python
with create_preview_context() as preview_ctx:
preview_script = preview_ctx.get_anime_preview(items, titles, config)
# ... use preview_script
# Workers are automatically cleaned up here
```
Returns:
PreviewContext: A context manager for preview operations
"""
return PreviewContext()
class PreviewContext:
"""Context manager for preview operations with automatic cleanup."""
def __init__(self):
self._manager = None
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self._manager:
try:
self._manager.shutdown_all(wait=False, timeout=3.0)
except Exception as e:
logger.warning(f"Failed to cleanup preview context: {e}")
def get_anime_preview(
self, items: List[MediaItem], titles: List[str], config: AppConfig
) -> str:
"""Get anime preview script with managed workers."""
if not self._manager:
self._manager = _get_preview_manager()
return get_anime_preview(items, titles, config)
def get_episode_preview(
self, episodes: List[str], media_item: MediaItem, config: AppConfig
) -> str:
"""Get episode preview script with managed workers."""
if not self._manager:
self._manager = _get_preview_manager()
return get_episode_preview(episodes, media_item, config)
def get_dynamic_anime_preview(self, config: AppConfig) -> str:
"""Get dynamic anime preview script for search functionality."""
if not self._manager:
self._manager = _get_preview_manager()
return get_dynamic_anime_preview(config)
def cancel_all_tasks(self) -> int:
"""Cancel all running preview tasks."""
if not self._manager:
return 0
cancelled = 0
if self._manager._preview_worker:
cancelled += self._manager._preview_worker.cancel_all_tasks()
if self._manager._episode_worker:
cancelled += self._manager._episode_worker.cancel_all_tasks()
return cancelled
def get_status(self) -> dict:
"""Get status of workers in this context."""
if self._manager:
return self._manager.get_status()
return {"preview_worker": None, "episode_worker": None}
def get_anime_preview(
items: List[MediaItem], titles: List[str], config: AppConfig
) -> str:
"""
Generate anime preview script and start background caching.
Args:
items: List of media items to preview
titles: Corresponding titles for each media item
config: Application configuration
Returns:
Preview script content for fzf
"""
# Ensure cache directories exist on startup
IMAGES_CACHE_DIR.mkdir(parents=True, exist_ok=True)
INFO_CACHE_DIR.mkdir(parents=True, exist_ok=True)
@@ -52,8 +130,15 @@ def get_anime_preview(
preview_script = TEMPLATE_PREVIEW_SCRIPT
# Start the non-blocking background Caching
Thread(target=_cache_worker, args=(items, titles, config), daemon=True).start()
# Start the managed background caching
try:
preview_manager = _get_preview_manager()
worker = preview_manager.get_preview_worker()
worker.cache_anime_previews(items, titles, config)
logger.debug("Started background caching for anime previews")
except Exception as e:
logger.error(f"Failed to start background caching: {e}")
# Continue with script generation even if caching fails
# Prepare values to inject into the template
path_sep = "\\" if PLATFORM == "win32" else "/"
@@ -80,97 +165,20 @@ def get_anime_preview(
return preview_script
def _cache_worker(media_items: List[MediaItem], titles: List[str], config: AppConfig):
"""The background task that fetches and saves all necessary preview data."""
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
for media_item, title_str in zip(media_items, titles):
hash_id = _get_cache_hash(title_str)
if config.general.preview in ("full", "image") and media_item.cover_image:
if not (IMAGES_CACHE_DIR / f"{hash_id}.png").exists():
executor.submit(
_save_image_from_url, media_item.cover_image.large, hash_id
)
if config.general.preview in ("full", "text"):
# TODO: Come up with a better caching pattern for now just let it be remade
if not (INFO_CACHE_DIR / hash_id).exists() or True:
info_text = _populate_info_template(media_item, config)
executor.submit(_save_info_text, info_text, hash_id)
def _populate_info_template(media_item: MediaItem, config: AppConfig) -> str:
"""
Takes the info.sh template and injects formatted, shell-safe data.
"""
info_script = TEMPLATE_INFO_SCRIPT
description = formatter.clean_html(
media_item.description or "No description available."
)
# Escape all variables before injecting them into the script
replacements = {
"TITLE": formatter.shell_safe(
media_item.title.english or media_item.title.romaji
),
"STATUS": formatter.shell_safe(media_item.status.value),
"FORMAT": formatter.shell_safe(media_item.format.value),
"NEXT_EPISODE": formatter.shell_safe(
f"Episode {media_item.next_airing.episode} on {formatter.format_date(media_item.next_airing.airing_at, '%A, %d %B %Y at %X)')}"
if media_item.next_airing
else "N/A"
),
"EPISODES": formatter.shell_safe(str(media_item.episodes)),
"DURATION": formatter.shell_safe(
formatter.format_media_duration(media_item.duration)
),
"SCORE": formatter.shell_safe(
formatter.format_score_stars_full(media_item.average_score)
),
"FAVOURITES": formatter.shell_safe(
formatter.format_number_with_commas(media_item.favourites)
),
"POPULARITY": formatter.shell_safe(
formatter.format_number_with_commas(media_item.popularity)
),
"GENRES": formatter.shell_safe(
formatter.format_list_with_commas([v.value for v in media_item.genres])
),
"TAGS": formatter.shell_safe(
formatter.format_list_with_commas([t.name.value for t in media_item.tags])
),
"STUDIOS": formatter.shell_safe(
formatter.format_list_with_commas(
[t.name for t in media_item.studios if t.name]
)
),
"SYNONYMNS": formatter.shell_safe(
formatter.format_list_with_commas(media_item.synonymns)
),
"USER_STATUS": formatter.shell_safe(
media_item.user_status.status.value
if media_item.user_status and media_item.user_status.status
else "NOT_ON_LIST"
),
"USER_PROGRESS": formatter.shell_safe(
f"Episode {media_item.user_status.progress}"
if media_item.user_status
else "0"
),
"START_DATE": formatter.shell_safe(
formatter.format_date(media_item.start_date)
),
"END_DATE": formatter.shell_safe(formatter.format_date(media_item.end_date)),
"SYNOPSIS": formatter.shell_safe(description),
}
for key, value in replacements.items():
info_script = info_script.replace(f"{{{key}}}", value)
return info_script
def get_episode_preview(
episodes: List[str], media_item: MediaItem, config: AppConfig
) -> str:
"""
Generate episode preview script and start background caching.
Args:
episodes: List of episode identifiers
media_item: Media item containing episode data
config: Application configuration
Returns:
Preview script content for fzf
"""
IMAGES_CACHE_DIR.mkdir(parents=True, exist_ok=True)
INFO_CACHE_DIR.mkdir(parents=True, exist_ok=True)
@@ -178,10 +186,16 @@ def get_episode_preview(
SEPARATOR_COLOR = config.fzf.preview_separator_color.split(",")
preview_script = TEMPLATE_PREVIEW_SCRIPT
# Start background caching for episodes
Thread(
target=_episode_cache_worker, args=(episodes, media_item, config), daemon=True
).start()
# Start managed background caching for episodes
try:
preview_manager = _get_preview_manager()
worker = preview_manager.get_episode_worker()
worker.cache_episode_previews(episodes, media_item, config)
logger.debug("Started background caching for episode previews")
except Exception as e:
logger.error(f"Failed to start episode background caching: {e}")
# Continue with script generation even if caching fails
# Prepare values to inject into the template
path_sep = "\\" if PLATFORM == "win32" else "/"
@@ -208,107 +222,86 @@ def get_episode_preview(
return preview_script
def _episode_cache_worker(
episodes: List[str], media_item: MediaItem, config: AppConfig
):
"""Background task that fetches and saves episode preview data."""
streaming_episodes = media_item.streaming_episodes
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
for episode_str in episodes:
hash_id = _get_cache_hash(
f"{media_item.title.english}_Episode_{episode_str}"
)
# Find matching streaming episode
title = None
thumbnail = None
if ep := streaming_episodes.get(episode_str):
title = ep.title
thumbnail = ep.thumbnail
# Fallback if no streaming episode found
if not title:
title = f"Episode {episode_str}"
# fallback
if not thumbnail and media_item.cover_image:
thumbnail = media_item.cover_image.large
# Download thumbnail if available
if thumbnail:
executor.submit(_save_image_from_url, thumbnail, hash_id)
# Generate and save episode info
episode_info = _populate_episode_info_template(config, title, media_item)
executor.submit(_save_info_text, episode_info, hash_id)
def _populate_episode_info_template(
config: AppConfig, title: str, media_item: MediaItem
) -> str:
def get_dynamic_anime_preview(config: AppConfig) -> str:
"""
Takes the episode_info.sh template and injects episode-specific formatted data.
Generate dynamic anime preview script for search functionality.
This is different from regular anime preview because:
1. We don't have media items upfront
2. The preview needs to work with search results as they come in
3. Preview is handled entirely in shell by parsing JSON results
Args:
config: Application configuration
Returns:
Preview script content for fzf dynamic search
"""
episode_info_script = TEMPLATE_EPISODE_INFO_SCRIPT
# Ensure cache directories exist
IMAGES_CACHE_DIR.mkdir(parents=True, exist_ok=True)
INFO_CACHE_DIR.mkdir(parents=True, exist_ok=True)
HEADER_COLOR = config.fzf.preview_header_color.split(",")
SEPARATOR_COLOR = config.fzf.preview_separator_color.split(",")
# Use the dynamic preview script template
preview_script = DYNAMIC_PREVIEW_SCRIPT
# We need to return the path to the search results file
from ...core.constants import APP_CACHE_DIR
search_cache_dir = APP_CACHE_DIR / "search"
search_results_file = search_cache_dir / "current_search_results.json"
# Prepare values to inject into the template
path_sep = "\\" if PLATFORM == "win32" else "/"
# Format the template with the dynamic values
replacements = {
"TITLE": formatter.shell_safe(title),
"NEXT_EPISODE": formatter.shell_safe(
f"Episode {media_item.next_airing.episode} on {formatter.format_date(media_item.next_airing.airing_at, '%A, %d %B %Y at %X)')}"
if media_item.next_airing
else "N/A"
),
"DURATION": formatter.format_media_duration(media_item.duration),
"STATUS": formatter.shell_safe(media_item.status.value),
"EPISODES": formatter.shell_safe(str(media_item.episodes)),
"USER_STATUS": formatter.shell_safe(
media_item.user_status.status.value
if media_item.user_status and media_item.user_status.status
else "NOT_ON_LIST"
),
"USER_PROGRESS": formatter.shell_safe(
f"Episode {media_item.user_status.progress}"
if media_item.user_status
else "0"
),
"START_DATE": formatter.shell_safe(
formatter.format_date(media_item.start_date)
),
"END_DATE": formatter.shell_safe(formatter.format_date(media_item.end_date)),
"PREVIEW_MODE": config.general.preview,
"IMAGE_CACHE_PATH": str(IMAGES_CACHE_DIR),
"INFO_CACHE_PATH": str(INFO_CACHE_DIR),
"PATH_SEP": path_sep,
"IMAGE_RENDERER": config.general.image_renderer,
"SEARCH_RESULTS_FILE": str(search_results_file),
# Color codes
"C_TITLE": ansi.get_true_fg(HEADER_COLOR, bold=True),
"C_KEY": ansi.get_true_fg(HEADER_COLOR, bold=True),
"C_VALUE": ansi.get_true_fg(HEADER_COLOR, bold=True),
"C_RULE": ansi.get_true_fg(SEPARATOR_COLOR, bold=True),
"RESET": ansi.RESET,
}
for key, value in replacements.items():
episode_info_script = episode_info_script.replace(f"{{{key}}}", value)
preview_script = preview_script.replace(f"{{{key}}}", value)
return episode_info_script
return preview_script
def _get_cache_hash(text: str) -> str:
"""Generates a consistent SHA256 hash for a given string to use as a filename."""
return sha256(text.encode("utf-8")).hexdigest()
def _get_preview_manager() -> PreviewWorkerManager:
"""Get or create the global preview worker manager."""
global _preview_manager
if _preview_manager is None:
_preview_manager = PreviewWorkerManager(IMAGES_CACHE_DIR, INFO_CACHE_DIR)
return _preview_manager
def _save_image_from_url(url: str, hash_id: str):
"""Downloads an image using httpx and saves it to the cache."""
image_path = IMAGES_CACHE_DIR / f"{hash_id}.png"
try:
with httpx.stream("GET", url, follow_redirects=True, timeout=20) as response:
response.raise_for_status()
with AtomicWriter(image_path, "wb", encoding=None) as f:
chunks = b""
for chunk in response.iter_bytes():
chunks += chunk
f.write(chunks)
except Exception as e:
logger.error(f"Failed to download image {url}: {e}")
def shutdown_preview_workers(wait: bool = True, timeout: Optional[float] = 5.0) -> None:
"""
Shutdown all preview workers.
Args:
wait: Whether to wait for tasks to complete
timeout: Maximum time to wait for shutdown
"""
global _preview_manager
if _preview_manager:
_preview_manager.shutdown_all(wait=wait, timeout=timeout)
_preview_manager = None
def _save_info_text(info_text: str, hash_id: str):
"""Saves pre-formatted text to the info cache."""
try:
info_path = INFO_CACHE_DIR / hash_id
with AtomicWriter(info_path) as f:
f.write(info_text)
except IOError as e:
logger.error(f"Failed to write info cache for {hash_id}: {e}")
def get_preview_worker_status() -> dict:
"""Get status of all preview workers."""
global _preview_manager
if _preview_manager:
return _preview_manager.get_status()
return {"preview_worker": None, "episode_worker": None}

View File

@@ -0,0 +1,475 @@
"""
Preview-specific background workers for caching anime and episode data.
This module provides specialized workers for handling anime preview caching,
including image downloads and info text generation with proper lifecycle management.
"""
import logging
from typing import List, Optional
import httpx
from ...core.constants import SCRIPTS_DIR
from ...core.config import AppConfig
from ...core.utils import formatter
from ...core.utils.concurrency import (
ManagedBackgroundWorker,
WorkerTask,
thread_manager,
)
from ...core.utils.file import AtomicWriter
from ...libs.media_api.types import MediaItem
logger = logging.getLogger(__name__)
FZF_SCRIPTS_DIR = SCRIPTS_DIR / "fzf"
TEMPLATE_INFO_SCRIPT = (FZF_SCRIPTS_DIR / "info.template.sh").read_text(
encoding="utf-8"
)
TEMPLATE_EPISODE_INFO_SCRIPT = (FZF_SCRIPTS_DIR / "episode-info.template.sh").read_text(
encoding="utf-8"
)
class PreviewCacheWorker(ManagedBackgroundWorker):
"""
Specialized background worker for caching anime preview data.
Handles downloading images and generating info text for anime previews
with proper error handling and resource management.
"""
def __init__(self, images_cache_dir, info_cache_dir, max_workers: int = 10):
"""
Initialize the preview cache worker.
Args:
images_cache_dir: Directory to cache images
info_cache_dir: Directory to cache info text
max_workers: Maximum number of concurrent workers
"""
super().__init__(max_workers=max_workers, name="PreviewCacheWorker")
self.images_cache_dir = images_cache_dir
self.info_cache_dir = info_cache_dir
self._http_client: Optional[httpx.Client] = None
def start(self) -> None:
"""Start the worker and initialize HTTP client."""
super().start()
self._http_client = httpx.Client(
timeout=20.0,
follow_redirects=True,
limits=httpx.Limits(max_connections=self.max_workers),
)
logger.debug("PreviewCacheWorker HTTP client initialized")
def shutdown(self, wait: bool = True, timeout: Optional[float] = 30.0) -> None:
"""Shutdown the worker and cleanup HTTP client."""
super().shutdown(wait=wait, timeout=timeout)
if self._http_client:
self._http_client.close()
self._http_client = None
logger.debug("PreviewCacheWorker HTTP client closed")
def cache_anime_previews(
self, media_items: List[MediaItem], titles: List[str], config: AppConfig
) -> None:
"""
Cache preview data for multiple anime items.
Args:
media_items: List of media items to cache
titles: Corresponding titles for each media item
config: Application configuration
"""
if not self.is_running():
raise RuntimeError("PreviewCacheWorker is not running")
for media_item, title_str in zip(media_items, titles):
hash_id = self._get_cache_hash(title_str)
# Submit image download task if needed
if config.general.preview in ("full", "image") and media_item.cover_image:
image_path = self.images_cache_dir / f"{hash_id}.png"
if not image_path.exists():
self.submit_function(
self._download_and_save_image,
media_item.cover_image.large,
hash_id,
)
# Submit info generation task if needed
if config.general.preview in ("full", "text"):
info_path = self.info_cache_dir / hash_id
if not info_path.exists():
info_text = self._generate_info_text(media_item, config)
self.submit_function(self._save_info_text, info_text, hash_id)
def _download_and_save_image(self, url: str, hash_id: str) -> None:
"""Download an image and save it to cache."""
if not self._http_client:
raise RuntimeError("HTTP client not initialized")
image_path = self.images_cache_dir / f"{hash_id}.png"
try:
with self._http_client.stream("GET", url) as response:
response.raise_for_status()
with AtomicWriter(image_path, "wb", encoding=None) as f:
for chunk in response.iter_bytes():
f.write(chunk)
logger.debug(f"Successfully cached image: {hash_id}")
except Exception as e:
logger.error(f"Failed to download image {url}: {e}")
raise
def _generate_info_text(self, media_item: MediaItem, config: AppConfig) -> str:
"""Generate formatted info text for a media item."""
# Import here to avoid circular imports
info_script = TEMPLATE_INFO_SCRIPT
description = formatter.clean_html(
media_item.description or "No description available."
)
# Escape all variables before injecting them into the script
replacements = {
"TITLE": formatter.shell_safe(
media_item.title.english or media_item.title.romaji
),
"STATUS": formatter.shell_safe(media_item.status.value),
"FORMAT": formatter.shell_safe(media_item.format.value),
"NEXT_EPISODE": formatter.shell_safe(
f"Episode {media_item.next_airing.episode} on {formatter.format_date(media_item.next_airing.airing_at, '%A, %d %B %Y at %X)')}"
if media_item.next_airing
else "N/A"
),
"EPISODES": formatter.shell_safe(str(media_item.episodes)),
"DURATION": formatter.shell_safe(
formatter.format_media_duration(media_item.duration)
),
"SCORE": formatter.shell_safe(
formatter.format_score_stars_full(media_item.average_score)
),
"FAVOURITES": formatter.shell_safe(
formatter.format_number_with_commas(media_item.favourites)
),
"POPULARITY": formatter.shell_safe(
formatter.format_number_with_commas(media_item.popularity)
),
"GENRES": formatter.shell_safe(
formatter.format_list_with_commas([v.value for v in media_item.genres])
),
"TAGS": formatter.shell_safe(
formatter.format_list_with_commas(
[t.name.value for t in media_item.tags]
)
),
"STUDIOS": formatter.shell_safe(
formatter.format_list_with_commas(
[t.name for t in media_item.studios if t.name]
)
),
"SYNONYMNS": formatter.shell_safe(
formatter.format_list_with_commas(media_item.synonymns)
),
"USER_STATUS": formatter.shell_safe(
media_item.user_status.status.value
if media_item.user_status and media_item.user_status.status
else "NOT_ON_LIST"
),
"USER_PROGRESS": formatter.shell_safe(
f"Episode {media_item.user_status.progress}"
if media_item.user_status
else "0"
),
"START_DATE": formatter.shell_safe(
formatter.format_date(media_item.start_date)
),
"END_DATE": formatter.shell_safe(
formatter.format_date(media_item.end_date)
),
"SYNOPSIS": formatter.shell_safe(description),
}
for key, value in replacements.items():
info_script = info_script.replace(f"{{{key}}}", value)
return info_script
def _save_info_text(self, info_text: str, hash_id: str) -> None:
"""Save info text to cache."""
try:
info_path = self.info_cache_dir / hash_id
with AtomicWriter(info_path) as f:
f.write(info_text)
logger.debug(f"Successfully cached info: {hash_id}")
except IOError as e:
logger.error(f"Failed to write info cache for {hash_id}: {e}")
raise
def _get_cache_hash(self, text: str) -> str:
"""Generate a cache hash for the given text."""
from hashlib import sha256
return sha256(text.encode("utf-8")).hexdigest()
def _on_task_completed(self, task: WorkerTask, future) -> None:
"""Handle task completion with enhanced logging."""
super()._on_task_completed(task, future)
if future.exception():
logger.warning(f"Preview cache task failed: {future.exception()}")
else:
logger.debug("Preview cache task completed successfully")
class EpisodeCacheWorker(ManagedBackgroundWorker):
"""
Specialized background worker for caching episode preview data.
Handles episode-specific caching including thumbnails and episode info
with proper error handling and resource management.
"""
def __init__(self, images_cache_dir, info_cache_dir, max_workers: int = 5):
"""
Initialize the episode cache worker.
Args:
images_cache_dir: Directory to cache images
info_cache_dir: Directory to cache info text
max_workers: Maximum number of concurrent workers
"""
super().__init__(max_workers=max_workers, name="EpisodeCacheWorker")
self.images_cache_dir = images_cache_dir
self.info_cache_dir = info_cache_dir
self._http_client: Optional[httpx.Client] = None
def start(self) -> None:
"""Start the worker and initialize HTTP client."""
super().start()
self._http_client = httpx.Client(
timeout=20.0,
follow_redirects=True,
limits=httpx.Limits(max_connections=self.max_workers),
)
logger.debug("EpisodeCacheWorker HTTP client initialized")
def shutdown(self, wait: bool = True, timeout: Optional[float] = 30.0) -> None:
"""Shutdown the worker and cleanup HTTP client."""
super().shutdown(wait=wait, timeout=timeout)
if self._http_client:
self._http_client.close()
self._http_client = None
logger.debug("EpisodeCacheWorker HTTP client closed")
def cache_episode_previews(
self, episodes: List[str], media_item: MediaItem, config: AppConfig
) -> None:
"""
Cache preview data for multiple episodes.
Args:
episodes: List of episode identifiers
media_item: Media item containing episode data
config: Application configuration
"""
if not self.is_running():
raise RuntimeError("EpisodeCacheWorker is not running")
streaming_episodes = media_item.streaming_episodes
for episode_str in episodes:
hash_id = self._get_cache_hash(
f"{media_item.title.english}_Episode_{episode_str}"
)
# Find episode data
episode_data = streaming_episodes.get(episode_str)
title = episode_data.title if episode_data else f"Episode {episode_str}"
thumbnail = None
if episode_data and episode_data.thumbnail:
thumbnail = episode_data.thumbnail
elif media_item.cover_image:
thumbnail = media_item.cover_image.large
# Submit thumbnail download task
if thumbnail:
self.submit_function(self._download_and_save_image, thumbnail, hash_id)
# Submit episode info generation task
episode_info = self._generate_episode_info(config, title, media_item)
self.submit_function(self._save_info_text, episode_info, hash_id)
def _download_and_save_image(self, url: str, hash_id: str) -> None:
"""Download an image and save it to cache."""
if not self._http_client:
raise RuntimeError("HTTP client not initialized")
image_path = self.images_cache_dir / f"{hash_id}.png"
try:
with self._http_client.stream("GET", url) as response:
response.raise_for_status()
with AtomicWriter(image_path, "wb", encoding=None) as f:
for chunk in response.iter_bytes():
f.write(chunk)
logger.debug(f"Successfully cached episode image: {hash_id}")
except Exception as e:
logger.error(f"Failed to download episode image {url}: {e}")
raise
def _generate_episode_info(
self, config: AppConfig, title: str, media_item: MediaItem
) -> str:
"""Generate formatted episode info text."""
episode_info_script = TEMPLATE_EPISODE_INFO_SCRIPT
replacements = {
"TITLE": formatter.shell_safe(title),
"NEXT_EPISODE": formatter.shell_safe(
f"Episode {media_item.next_airing.episode} on {formatter.format_date(media_item.next_airing.airing_at, '%A, %d %B %Y at %X)')}"
if media_item.next_airing
else "N/A"
),
"DURATION": formatter.format_media_duration(media_item.duration),
"STATUS": formatter.shell_safe(media_item.status.value),
"EPISODES": formatter.shell_safe(str(media_item.episodes)),
"USER_STATUS": formatter.shell_safe(
media_item.user_status.status.value
if media_item.user_status and media_item.user_status.status
else "NOT_ON_LIST"
),
"USER_PROGRESS": formatter.shell_safe(
f"Episode {media_item.user_status.progress}"
if media_item.user_status
else "0"
),
"START_DATE": formatter.shell_safe(
formatter.format_date(media_item.start_date)
),
"END_DATE": formatter.shell_safe(
formatter.format_date(media_item.end_date)
),
}
for key, value in replacements.items():
episode_info_script = episode_info_script.replace(f"{{{key}}}", value)
return episode_info_script
def _save_info_text(self, info_text: str, hash_id: str) -> None:
"""Save episode info text to cache."""
try:
info_path = self.info_cache_dir / hash_id
with AtomicWriter(info_path) as f:
f.write(info_text)
logger.debug(f"Successfully cached episode info: {hash_id}")
except IOError as e:
logger.error(f"Failed to write episode info cache for {hash_id}: {e}")
raise
def _get_cache_hash(self, text: str) -> str:
"""Generate a cache hash for the given text."""
from hashlib import sha256
return sha256(text.encode("utf-8")).hexdigest()
def _on_task_completed(self, task: WorkerTask, future) -> None:
"""Handle task completion with enhanced logging."""
super()._on_task_completed(task, future)
if future.exception():
logger.warning(f"Episode cache task failed: {future.exception()}")
else:
logger.debug("Episode cache task completed successfully")
class PreviewWorkerManager:
"""
High-level manager for preview caching workers.
Provides a simple interface for managing both anime and episode preview
caching workers with automatic lifecycle management.
"""
def __init__(self, images_cache_dir, info_cache_dir):
"""
Initialize the preview worker manager.
Args:
images_cache_dir: Directory to cache images
info_cache_dir: Directory to cache info text
"""
self.images_cache_dir = images_cache_dir
self.info_cache_dir = info_cache_dir
self._preview_worker: Optional[PreviewCacheWorker] = None
self._episode_worker: Optional[EpisodeCacheWorker] = None
def get_preview_worker(self) -> PreviewCacheWorker:
"""Get or create the preview cache worker."""
if self._preview_worker is None or not self._preview_worker.is_running():
if self._preview_worker:
# Clean up old worker
thread_manager.shutdown_worker("preview_cache_worker")
self._preview_worker = PreviewCacheWorker(
self.images_cache_dir, self.info_cache_dir
)
self._preview_worker.start()
thread_manager.register_worker("preview_cache_worker", self._preview_worker)
return self._preview_worker
def get_episode_worker(self) -> EpisodeCacheWorker:
"""Get or create the episode cache worker."""
if self._episode_worker is None or not self._episode_worker.is_running():
if self._episode_worker:
# Clean up old worker
thread_manager.shutdown_worker("episode_cache_worker")
self._episode_worker = EpisodeCacheWorker(
self.images_cache_dir, self.info_cache_dir
)
self._episode_worker.start()
thread_manager.register_worker("episode_cache_worker", self._episode_worker)
return self._episode_worker
def shutdown_all(self, wait: bool = True, timeout: Optional[float] = 30.0) -> None:
"""Shutdown all managed workers."""
thread_manager.shutdown_worker(
"preview_cache_worker", wait=wait, timeout=timeout
)
thread_manager.shutdown_worker(
"episode_cache_worker", wait=wait, timeout=timeout
)
self._preview_worker = None
self._episode_worker = None
def get_status(self) -> dict:
"""Get status of all managed workers."""
return {
"preview_worker": self._preview_worker.get_completion_stats()
if self._preview_worker
else None,
"episode_worker": self._episode_worker.get_completion_stats()
if self._episode_worker
else None,
}
def __enter__(self):
"""Context manager entry - workers are created on demand."""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Context manager exit with automatic cleanup."""
self.shutdown_all(wait=False, timeout=5.0)

View File

@@ -9,7 +9,7 @@ import sys
from httpx import get
from rich import print
from ...core.constants import AUTHOR, GIT_REPO, PROJECT_NAME_LOWER, __version__
from ...core.constants import AUTHOR, GIT_REPO, PROJECT_NAME_LOWER, __version__
API_URL = f"https://api.{GIT_REPO}/repos/{AUTHOR}/{PROJECT_NAME_LOWER}/releases/latest"
@@ -98,7 +98,7 @@ def update_app(force=False):
process = subprocess.run(
[NIX, "profile", "upgrade", PROJECT_NAME_LOWER], check=False
)
elif is_git_repo(AUTHOR, PROJECT_NAME_LOWER) :
elif is_git_repo(AUTHOR, PROJECT_NAME_LOWER):
GIT_EXECUTABLE = shutil.which("git")
args = [
GIT_EXECUTABLE,
@@ -117,7 +117,9 @@ def update_app(force=False):
)
elif UV := shutil.which("uv"):
process = subprocess.run([UV, "tool", "upgrade", PROJECT_NAME_LOWER], check=False)
process = subprocess.run(
[UV, "tool", "upgrade", PROJECT_NAME_LOWER], check=False
)
elif PIPX := shutil.which("pipx"):
process = subprocess.run([PIPX, "upgrade", PROJECT_NAME_LOWER], check=False)
else:

View File

@@ -34,14 +34,14 @@ logger = logging.getLogger(__name__)
class DefaultDownloader(BaseDownloader):
"""Default downloader that uses httpx for downloads without yt-dlp dependency."""
def download(self, params: DownloadParams) -> DownloadResult:
"""Download video and optionally subtitles, returning detailed results."""
try:
video_path = None
sub_paths = []
merged_path = None
if TORRENT_REGEX.match(params.url):
from .torrents import download_torrent_with_webtorrent_cli
@@ -51,24 +51,26 @@ class DefaultDownloader(BaseDownloader):
dest_dir.mkdir(parents=True, exist_ok=True)
video_path = dest_dir / episode_title
video_path = download_torrent_with_webtorrent_cli(video_path, params.url)
video_path = download_torrent_with_webtorrent_cli(
video_path, params.url
)
else:
video_path = self._download_video(params)
if params.subtitles:
sub_paths = self._download_subs(params)
if params.merge:
merged_path = self._merge_subtitles(params, video_path, sub_paths)
return DownloadResult(
success=True,
video_path=video_path,
subtitle_paths=sub_paths,
merged_path=merged_path,
anime_title=params.anime_title,
episode_title=params.episode_title
episode_title=params.episode_title,
)
except KeyboardInterrupt:
print()
print("Aborted!")
@@ -76,7 +78,7 @@ class DefaultDownloader(BaseDownloader):
success=False,
error_message="Download aborted by user",
anime_title=params.anime_title,
episode_title=params.episode_title
episode_title=params.episode_title,
)
except Exception as e:
logger.error(f"Download failed: {e}")
@@ -84,43 +86,41 @@ class DefaultDownloader(BaseDownloader):
success=False,
error_message=str(e),
anime_title=params.anime_title,
episode_title=params.episode_title
episode_title=params.episode_title,
)
def _download_video(self, params: DownloadParams) -> Path:
"""Download video using httpx with progress tracking."""
anime_title = sanitize_filename(params.anime_title)
episode_title = sanitize_filename(params.episode_title)
dest_dir = self.config.downloads_dir / anime_title
dest_dir.mkdir(parents=True, exist_ok=True)
# Get file extension from URL or headers
file_extension = self._get_file_extension(params.url, params.headers)
if params.force_unknown_ext and not file_extension:
file_extension = ".unknown_video"
elif not file_extension:
file_extension = ".mp4" # default fallback
video_path = dest_dir / f"{episode_title}{file_extension}"
# Check if file already exists
if video_path.exists() and not params.prompt:
logger.info(f"File already exists: {video_path}")
return video_path
elif video_path.exists() and params.prompt:
if not Confirm.ask(f"File exists: {video_path.name}. Overwrite?", default=False):
if not Confirm.ask(
f"File exists: {video_path.name}. Overwrite?", default=False
):
return video_path
# Download with progress tracking
self._download_with_progress(
params.url,
video_path,
params.headers,
params.silent,
params.progress_hooks
params.url, video_path, params.headers, params.silent, params.progress_hooks
)
# Handle unknown video extension normalization
if video_path.suffix == ".unknown_video":
normalized_path = video_path.with_suffix(".mp4")
@@ -128,7 +128,7 @@ class DefaultDownloader(BaseDownloader):
shutil.move(video_path, normalized_path)
print("Successfully normalized path")
return normalized_path
return video_path
def _get_file_extension(self, url: str, headers: dict) -> str:
@@ -136,55 +136,64 @@ class DefaultDownloader(BaseDownloader):
# First try to get from URL
parsed_url = urllib.parse.urlparse(url)
path = parsed_url.path
if path and '.' in path:
if path and "." in path:
return Path(path).suffix
# Try to get from response headers
try:
with self.client.stream('HEAD', url, headers=headers) as response:
content_type = response.headers.get('content-type', '')
if 'video/mp4' in content_type:
return '.mp4'
elif 'video/webm' in content_type:
return '.webm'
elif 'video/x-matroska' in content_type:
return '.mkv'
elif 'video/x-msvideo' in content_type:
return '.avi'
elif 'video/quicktime' in content_type:
return '.mov'
with self.client.stream("HEAD", url, headers=headers) as response:
content_type = response.headers.get("content-type", "")
if "video/mp4" in content_type:
return ".mp4"
elif "video/webm" in content_type:
return ".webm"
elif "video/x-matroska" in content_type:
return ".mkv"
elif "video/x-msvideo" in content_type:
return ".avi"
elif "video/quicktime" in content_type:
return ".mov"
# Try content-disposition header
content_disposition = response.headers.get('content-disposition', '')
if 'filename=' in content_disposition:
filename = content_disposition.split('filename=')[1].strip('"\'')
content_disposition = response.headers.get("content-disposition", "")
if "filename=" in content_disposition:
filename = content_disposition.split("filename=")[1].strip("\"'")
return Path(filename).suffix
except Exception:
pass
return ""
def _download_with_progress(self, url: str, output_path: Path, headers: dict, silent: bool, progress_hooks: list | None = None):
def _download_with_progress(
self,
url: str,
output_path: Path,
headers: dict,
silent: bool,
progress_hooks: list | None = None,
):
"""Download file with rich progress bar and progress hooks."""
progress_hooks = progress_hooks or []
# Always show download start message
print(f"[cyan]Starting download of {output_path.name}...[/]")
try:
with self.client.stream('GET', url, headers=headers) as response:
with self.client.stream("GET", url, headers=headers) as response:
response.raise_for_status()
total_size = int(response.headers.get('content-length', 0))
total_size = int(response.headers.get("content-length", 0))
downloaded = 0
# Initialize progress display - always show progress
progress = None
task_id = None
if total_size > 0:
progress = Progress(
TextColumn("[bold blue]{task.fields[filename]}", justify="right"),
TextColumn(
"[bold blue]{task.fields[filename]}", justify="right"
),
BarColumn(bar_width=None),
"[progress.percentage]{task.percentage:>3.1f}%",
"",
@@ -197,75 +206,77 @@ class DefaultDownloader(BaseDownloader):
else:
# Progress without total size (indeterminate)
progress = Progress(
TextColumn("[bold blue]{task.fields[filename]}", justify="right"),
TextColumn(
"[bold blue]{task.fields[filename]}", justify="right"
),
TextColumn("[green]{task.completed} bytes"),
"",
TransferSpeedColumn(),
)
progress.start()
task_id = progress.add_task(
"download",
filename=output_path.name,
total=total_size if total_size > 0 else None
"download",
filename=output_path.name,
total=total_size if total_size > 0 else None,
)
try:
with open(output_path, 'wb') as f:
with open(output_path, "wb") as f:
for chunk in response.iter_bytes(chunk_size=8192):
if chunk:
f.write(chunk)
chunk_size = len(chunk)
downloaded += chunk_size
# Always update progress bar
if progress is not None and task_id is not None:
progress.update(task_id, advance=chunk_size)
# Call progress hooks
if progress_hooks:
progress_info = {
'downloaded_bytes': downloaded,
'total_bytes': total_size,
'filename': output_path.name,
'status': 'downloading'
"downloaded_bytes": downloaded,
"total_bytes": total_size,
"filename": output_path.name,
"status": "downloading",
}
for hook in progress_hooks:
try:
hook(progress_info)
except Exception as e:
logger.warning(f"Progress hook failed: {e}")
finally:
if progress:
progress.stop()
# Always show completion message
print(f"[green]✓ Download completed: {output_path.name}[/]")
# Call completion hooks
if progress_hooks:
completion_info = {
'downloaded_bytes': downloaded,
'total_bytes': total_size or downloaded,
'filename': output_path.name,
'status': 'finished'
"downloaded_bytes": downloaded,
"total_bytes": total_size or downloaded,
"filename": output_path.name,
"status": "finished",
}
for hook in progress_hooks:
try:
hook(completion_info)
except Exception as e:
logger.warning(f"Progress hook failed: {e}")
except httpx.HTTPError as e:
# Call error hooks
if progress_hooks:
error_info = {
'downloaded_bytes': 0,
'total_bytes': 0,
'filename': output_path.name,
'status': 'error',
'error': str(e)
"downloaded_bytes": 0,
"total_bytes": 0,
"filename": output_path.name,
"status": "error",
"error": str(e),
}
for hook in progress_hooks:
try:
@@ -280,12 +291,12 @@ class DefaultDownloader(BaseDownloader):
episode_title = sanitize_filename(params.episode_title)
base = self.config.downloads_dir / anime_title
downloaded_subs = []
for i, sub_url in enumerate(params.subtitles):
try:
response = self.client.get(sub_url, headers=params.headers)
response.raise_for_status()
# Determine filename
filename = get_remote_filename(response)
if not filename:
@@ -293,79 +304,87 @@ class DefaultDownloader(BaseDownloader):
filename = f"{episode_title}.srt"
else:
filename = f"{episode_title}.{i}.srt"
sub_path = base / filename
# Write subtitle content
with open(sub_path, 'w', encoding='utf-8') as f:
with open(sub_path, "w", encoding="utf-8") as f:
f.write(response.text)
downloaded_subs.append(sub_path)
print(f"Downloaded subtitle: {filename}")
except httpx.HTTPError as e:
logger.error(f"Failed to download subtitle {i}: {e}")
print(f"[red]Failed to download subtitle {i}: {e}[/red]")
return downloaded_subs
def _merge_subtitles(self, params: DownloadParams, video_path: Path, sub_paths: list[Path]) -> Optional[Path]:
def _merge_subtitles(
self, params: DownloadParams, video_path: Path, sub_paths: list[Path]
) -> Optional[Path]:
"""Merge subtitles with video using ffmpeg and return the path to the merged file."""
ffmpeg_executable = shutil.which("ffmpeg")
if not ffmpeg_executable:
raise FastAnimeError("Please install ffmpeg in order to merge subtitles")
merged_filename = video_path.stem + ".mkv"
# Prepare subtitle input arguments
subs_input_args = list(
itertools.chain.from_iterable(
[["-i", str(sub_path)] for sub_path in sub_paths]
)
)
with tempfile.TemporaryDirectory() as temp_dir_str:
temp_dir = Path(temp_dir_str)
temp_output_path = temp_dir / merged_filename
# Construct ffmpeg command
args = [
ffmpeg_executable,
"-hide_banner",
"-i", str(video_path), # Main video input
"-i",
str(video_path), # Main video input
*subs_input_args, # All subtitle inputs
"-c", "copy", # Copy streams without re-encoding
"-map", "0:v", # Map video from first input
"-map", "0:a", # Map audio from first input
"-c",
"copy", # Copy streams without re-encoding
"-map",
"0:v", # Map video from first input
"-map",
"0:a", # Map audio from first input
]
# Map subtitle streams from each subtitle input
for i in range(len(sub_paths)):
args.extend(["-map", f"{i + 1}:s"])
args.append(str(temp_output_path))
print(f"[cyan]Starting subtitle merge for {video_path.name}...[/]")
try:
# Run ffmpeg - use silent flag to control ffmpeg output, not progress
process = subprocess.run(
args,
args,
capture_output=params.silent, # Only suppress ffmpeg output if silent
text=True,
check=True
check=True,
)
final_output_path = video_path.parent / merged_filename
# Handle existing file
if final_output_path.exists():
if not params.prompt or Confirm.ask(
f"File exists ({final_output_path}). Overwrite?",
default=True,
):
print(f"[yellow]Overwriting existing file: {final_output_path}[/]")
print(
f"[yellow]Overwriting existing file: {final_output_path}[/]"
)
final_output_path.unlink()
shutil.move(str(temp_output_path), str(final_output_path))
else:
@@ -373,18 +392,20 @@ class DefaultDownloader(BaseDownloader):
return None
else:
shutil.move(str(temp_output_path), str(final_output_path))
# Clean up original files if requested
if params.clean:
print("[cyan]Cleaning original files...[/]")
video_path.unlink()
for sub_path in sub_paths:
sub_path.unlink()
print(f"[green bold]Subtitles merged successfully.[/] Output: {final_output_path}")
print(
f"[green bold]Subtitles merged successfully.[/] Output: {final_output_path}"
)
return final_output_path
except subprocess.CalledProcessError as e:
error_msg = f"FFmpeg failed: {e.stderr if e.stderr else str(e)}"
logger.error(error_msg)

View File

@@ -30,9 +30,11 @@ class DownloadFactory:
try:
import yt_dlp
from .yt_dlp import YtDLPDownloader
return YtDLPDownloader(config)
except ImportError:
from .default import DefaultDownloader
return DefaultDownloader(config)
else:
raise FastAnimeError("Downloader not implemented")

View File

@@ -8,25 +8,22 @@ from pydantic import BaseModel, Field
class DownloadResult(BaseModel):
"""Result of a download operation."""
success: bool = Field(description="Whether the download was successful")
video_path: Optional[Path] = Field(
default=None,
description="Path to the downloaded video file"
default=None, description="Path to the downloaded video file"
)
subtitle_paths: list[Path] = Field(
default_factory=list,
description="Paths to downloaded subtitle files"
default_factory=list, description="Paths to downloaded subtitle files"
)
merged_path: Optional[Path] = Field(
default=None,
description="Path to the merged video+subtitles file if merge was performed"
description="Path to the merged video+subtitles file if merge was performed",
)
error_message: Optional[str] = Field(
default=None,
description="Error message if download failed"
default=None, description="Error message if download failed"
)
anime_title: str = Field(description="Title of the anime")
episode_title: str = Field(description="Title of the episode")
model_config = {"arbitrary_types_allowed": True}

View File

@@ -4,13 +4,13 @@ import subprocess
import tempfile
import time
from pathlib import Path
from typing import Optional, Dict, Any, Callable, Union
from urllib.parse import urlparse
from typing import Optional, Dict, Any, Callable
from ..exceptions import FastAnimeError, DependencyNotFoundError
try:
import libtorrent as lt
LIBTORRENT_AVAILABLE = True
except ImportError:
LIBTORRENT_AVAILABLE = False
@@ -21,6 +21,7 @@ logger = logging.getLogger(__name__)
class TorrentDownloadError(FastAnimeError):
"""Raised when torrent download fails."""
pass
@@ -37,7 +38,7 @@ class TorrentDownloader:
max_download_rate: int = -1, # -1 means unlimited
max_connections: int = 200,
listen_port: int = 6881,
progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None
progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
):
"""
Initialize the torrent downloader.
@@ -65,33 +66,33 @@ class TorrentDownloader:
raise DependencyNotFoundError("libtorrent is not available")
session = lt.session() # type: ignore
# Configure session settings
settings = {
'user_agent': 'FastAnime/1.0',
'listen_interfaces': f'0.0.0.0:{self.listen_port}',
'enable_outgoing_utp': True,
'enable_incoming_utp': True,
'enable_outgoing_tcp': True,
'enable_incoming_tcp': True,
'connections_limit': self.max_connections,
'dht_bootstrap_nodes': 'dht.transmissionbt.com:6881',
"user_agent": "FastAnime/1.0",
"listen_interfaces": f"0.0.0.0:{self.listen_port}",
"enable_outgoing_utp": True,
"enable_incoming_utp": True,
"enable_outgoing_tcp": True,
"enable_incoming_tcp": True,
"connections_limit": self.max_connections,
"dht_bootstrap_nodes": "dht.transmissionbt.com:6881",
}
if self.max_upload_rate > 0:
settings['upload_rate_limit'] = self.max_upload_rate * 1024
settings["upload_rate_limit"] = self.max_upload_rate * 1024
if self.max_download_rate > 0:
settings['download_rate_limit'] = self.max_download_rate * 1024
settings["download_rate_limit"] = self.max_download_rate * 1024
session.apply_settings(settings)
# Start DHT
session.start_dht()
# Add trackers
session.add_dht_router('router.bittorrent.com', 6881)
session.add_dht_router('router.utorrent.com', 6881)
session.add_dht_router("router.bittorrent.com", 6881)
session.add_dht_router("router.utorrent.com", 6881)
logger.info("Libtorrent session configured successfully")
return session
@@ -100,29 +101,29 @@ class TorrentDownloader:
if not LIBTORRENT_AVAILABLE or lt is None:
raise DependencyNotFoundError("libtorrent is not available")
if torrent_source.startswith('magnet:'):
if torrent_source.startswith("magnet:"):
# Parse magnet link
return lt.parse_magnet_uri(torrent_source) # type: ignore
elif torrent_source.startswith(('http://', 'https://')):
elif torrent_source.startswith(("http://", "https://")):
# Download torrent file
import urllib.request
with tempfile.NamedTemporaryFile(suffix='.torrent', delete=False) as tmp_file:
with tempfile.NamedTemporaryFile(
suffix=".torrent", delete=False
) as tmp_file:
urllib.request.urlretrieve(torrent_source, tmp_file.name)
torrent_info = lt.torrent_info(tmp_file.name) # type: ignore
Path(tmp_file.name).unlink() # Clean up temp file
return {'ti': torrent_info}
return {"ti": torrent_info}
else:
# Local torrent file
torrent_path = Path(torrent_source)
if not torrent_path.exists():
raise TorrentDownloadError(f"Torrent file not found: {torrent_source}")
return {'ti': lt.torrent_info(str(torrent_path))} # type: ignore
return {"ti": lt.torrent_info(str(torrent_path))} # type: ignore
def download_with_libtorrent(
self,
torrent_source: str,
timeout: int = 3600,
sequential: bool = False
self, torrent_source: str, timeout: int = 3600, sequential: bool = False
) -> Path:
"""
Download torrent using libtorrent.
@@ -148,48 +149,50 @@ class TorrentDownloader:
try:
self.session = self._setup_libtorrent_session()
torrent_params = self._get_torrent_info(torrent_source)
# Set save path
torrent_params['save_path'] = str(self.download_path)
torrent_params["save_path"] = str(self.download_path)
if sequential and lt is not None:
torrent_params['flags'] = lt.torrent_flags.sequential_download # type: ignore
torrent_params["flags"] = lt.torrent_flags.sequential_download # type: ignore
# Add torrent to session
if self.session is None:
raise TorrentDownloadError("Session is not initialized")
handle = self.session.add_torrent(torrent_params)
logger.info(f"Starting torrent download: {handle.name()}")
# Monitor download progress
start_time = time.time()
last_log_time = start_time
while not handle.is_seed():
current_time = time.time()
# Check timeout
if current_time - start_time > timeout:
raise TorrentDownloadError(f"Download timeout after {timeout} seconds")
raise TorrentDownloadError(
f"Download timeout after {timeout} seconds"
)
status = handle.status()
# Prepare progress info
progress_info = {
'name': handle.name(),
'progress': status.progress * 100,
'download_rate': status.download_rate / 1024, # KB/s
'upload_rate': status.upload_rate / 1024, # KB/s
'num_peers': status.num_peers,
'total_size': status.total_wanted,
'downloaded': status.total_wanted_done,
'state': str(status.state),
"name": handle.name(),
"progress": status.progress * 100,
"download_rate": status.download_rate / 1024, # KB/s
"upload_rate": status.upload_rate / 1024, # KB/s
"num_peers": status.num_peers,
"total_size": status.total_wanted,
"downloaded": status.total_wanted_done,
"state": str(status.state),
}
# Call progress callback if provided
if self.progress_callback:
self.progress_callback(progress_info)
# Log progress periodically (every 10 seconds)
if current_time - last_log_time >= 10:
logger.info(
@@ -198,23 +201,23 @@ class TorrentDownloader:
f"- {progress_info['num_peers']} peers"
)
last_log_time = current_time
# Check for errors
if status.error:
raise TorrentDownloadError(f"Torrent error: {status.error}")
time.sleep(1)
# Download completed
download_path = self.download_path / handle.name()
logger.info(f"Torrent download completed: {download_path}")
# Remove torrent from session
if self.session is not None:
self.session.remove_torrent(handle)
return download_path
except Exception as e:
if isinstance(e, (TorrentDownloadError, DependencyNotFoundError)):
raise
@@ -243,40 +246,52 @@ class TorrentDownloader:
raise DependencyNotFoundError(
"webtorrent CLI is not available. Please install it: npm install -g webtorrent-cli"
)
try:
cmd = [webtorrent_cli, "download", torrent_source, "--out", str(self.download_path)]
cmd = [
webtorrent_cli,
"download",
torrent_source,
"--out",
str(self.download_path),
]
logger.info(f"Running webtorrent command: {' '.join(cmd)}")
result = subprocess.run(cmd, check=True, capture_output=True, text=True, timeout=3600)
result = subprocess.run(
cmd, check=True, capture_output=True, text=True, timeout=3600
)
# Try to determine the download path from the output
# This is a best-effort approach since webtorrent output format may vary
output_lines = result.stdout.split('\n')
output_lines = result.stdout.split("\n")
for line in output_lines:
if 'Downloaded' in line and 'to' in line:
if "Downloaded" in line and "to" in line:
# Extract path from output
parts = line.split('to')
parts = line.split("to")
if len(parts) > 1:
path_str = parts[-1].strip().strip('"\'') # Remove quotes
path_str = parts[-1].strip().strip("\"'") # Remove quotes
download_path = Path(path_str)
if download_path.exists():
logger.info(f"Successfully downloaded to: {download_path}")
return download_path
# If we can't parse the output, scan the download directory for new files
logger.warning("Could not parse webtorrent output, scanning download directory")
logger.warning(
"Could not parse webtorrent output, scanning download directory"
)
download_candidates = list(self.download_path.iterdir())
if download_candidates:
# Return the most recently modified item
newest_path = max(download_candidates, key=lambda p: p.stat().st_mtime)
logger.info(f"Found downloaded content: {newest_path}")
return newest_path
# Fallback: return the download directory
logger.warning(f"No specific download found, returning download directory: {self.download_path}")
logger.warning(
f"No specific download found, returning download directory: {self.download_path}"
)
return self.download_path
except subprocess.CalledProcessError as e:
error_msg = e.stderr or e.stdout or "Unknown error"
raise TorrentDownloadError(
@@ -287,13 +302,12 @@ class TorrentDownloader:
f"webtorrent CLI timeout after {e.timeout} seconds"
) from e
except Exception as e:
raise TorrentDownloadError(f"Failed to download with webtorrent: {str(e)}") from e
raise TorrentDownloadError(
f"Failed to download with webtorrent: {str(e)}"
) from e
def download(
self,
torrent_source: str,
prefer_libtorrent: bool = True,
**kwargs
self, torrent_source: str, prefer_libtorrent: bool = True, **kwargs
) -> Path:
"""
Download torrent using the best available method.
@@ -310,24 +324,28 @@ class TorrentDownloader:
TorrentDownloadError: If all download methods fail
"""
methods = []
if prefer_libtorrent and LIBTORRENT_AVAILABLE:
methods.extend([
('libtorrent', self.download_with_libtorrent),
('webtorrent-cli', self.download_with_webtorrent_cli)
])
methods.extend(
[
("libtorrent", self.download_with_libtorrent),
("webtorrent-cli", self.download_with_webtorrent_cli),
]
)
else:
methods.extend([
('webtorrent-cli', self.download_with_webtorrent_cli),
('libtorrent', self.download_with_libtorrent)
])
methods.extend(
[
("webtorrent-cli", self.download_with_webtorrent_cli),
("libtorrent", self.download_with_libtorrent),
]
)
last_exception = None
for method_name, method_func in methods:
try:
logger.info(f"Attempting download with {method_name}")
if method_name == 'libtorrent':
if method_name == "libtorrent":
return method_func(torrent_source, **kwargs)
else:
return method_func(torrent_source)
@@ -339,7 +357,7 @@ class TorrentDownloader:
logger.error(f"{method_name} failed: {e}")
last_exception = e
continue
# All methods failed
raise TorrentDownloadError(
f"All torrent download methods failed. Last error: {last_exception}"

View File

@@ -28,7 +28,7 @@ class YtDLPDownloader(BaseDownloader):
video_path = None
sub_paths = []
merged_path = None
if TORRENT_REGEX.match(params.url):
from .torrents import download_torrent_with_webtorrent_cli
@@ -38,24 +38,26 @@ class YtDLPDownloader(BaseDownloader):
dest_dir.mkdir(parents=True, exist_ok=True)
video_path = dest_dir / episode_title
video_path = download_torrent_with_webtorrent_cli(video_path, params.url)
video_path = download_torrent_with_webtorrent_cli(
video_path, params.url
)
else:
video_path = self._download_video(params)
if params.subtitles:
sub_paths = self._download_subs(params)
if params.merge:
merged_path = self._merge_subtitles(params, video_path, sub_paths)
return DownloadResult(
success=True,
video_path=video_path,
subtitle_paths=sub_paths,
merged_path=merged_path,
anime_title=params.anime_title,
episode_title=params.episode_title
episode_title=params.episode_title,
)
except KeyboardInterrupt:
print()
print("Aborted!")
@@ -63,7 +65,7 @@ class YtDLPDownloader(BaseDownloader):
success=False,
error_message="Download aborted by user",
anime_title=params.anime_title,
episode_title=params.episode_title
episode_title=params.episode_title,
)
except Exception as e:
logger.error(f"Download failed: {e}")
@@ -71,7 +73,7 @@ class YtDLPDownloader(BaseDownloader):
success=False,
error_message=str(e),
anime_title=params.anime_title,
episode_title=params.episode_title
episode_title=params.episode_title,
)
def _download_video(self, params: DownloadParams) -> Path:
@@ -167,7 +169,9 @@ class YtDLPDownloader(BaseDownloader):
downloaded_subs.append(sub_path)
return downloaded_subs
def _merge_subtitles(self, params, video_path: Path, sub_paths: list[Path]) -> Path | None:
def _merge_subtitles(
self, params, video_path: Path, sub_paths: list[Path]
) -> Path | None:
"""Merge subtitles with video and return the path to the merged file."""
self.FFMPEG_EXECUTABLE = shutil.which("ffmpeg")
if not self.FFMPEG_EXECUTABLE:
@@ -243,7 +247,7 @@ class YtDLPDownloader(BaseDownloader):
f"[green bold]Subtitles merged successfully.[/] Output file: {final_output_path}"
)
return final_output_path
except Exception as e:
print(f"[red bold]An unexpected error[/] occurred: {e}")
return None

View File

@@ -0,0 +1,7 @@
"""
Core utilities for FastAnime application.
This module provides various utility classes and functions used throughout
the FastAnime application, including concurrency management, file operations,
and other common functionality.
"""

View File

@@ -0,0 +1,389 @@
"""
Concurrency utilities for managing background tasks and thread lifecycle.
This module provides abstract base classes and concrete implementations for managing
background workers with proper lifecycle control, cancellation support, and resource cleanup.
"""
import logging
import threading
from abc import ABC, abstractmethod
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable, Dict, List, Optional, Protocol, TypeVar
from weakref import WeakSet
logger = logging.getLogger(__name__)
T = TypeVar("T")
class Cancellable(Protocol):
"""Protocol for objects that can be cancelled."""
def cancel(self) -> bool:
"""Cancel the operation. Returns True if cancellation was successful."""
...
def cancelled(self) -> bool:
"""Return True if the operation was cancelled."""
...
class WorkerTask:
"""Represents a single task that can be executed by a worker."""
def __init__(self, func: Callable[..., Any], *args, **kwargs):
"""
Initialize a worker task.
Args:
func: The function to execute
*args: Positional arguments for the function
**kwargs: Keyword arguments for the function
"""
self.func = func
self.args = args
self.kwargs = kwargs
self._cancelled = threading.Event()
self._completed = threading.Event()
self._exception: Optional[Exception] = None
self._result: Any = None
def execute(self) -> Any:
"""Execute the task if not cancelled."""
if self._cancelled.is_set():
return None
try:
self._result = self.func(*self.args, **self.kwargs)
return self._result
except Exception as e:
self._exception = e
logger.error(f"Task execution failed: {e}")
raise
finally:
self._completed.set()
def cancel(self) -> bool:
"""Cancel the task."""
if self._completed.is_set():
return False
self._cancelled.set()
return True
def cancelled(self) -> bool:
"""Check if the task was cancelled."""
return self._cancelled.is_set()
def completed(self) -> bool:
"""Check if the task completed."""
return self._completed.is_set()
@property
def exception(self) -> Optional[Exception]:
"""Get the exception if one occurred during execution."""
return self._exception
@property
def result(self) -> Any:
"""Get the result of the task execution."""
return self._result
class BackgroundWorker(ABC):
"""
Abstract base class for background workers that manage concurrent tasks.
Provides lifecycle management, cancellation support, and proper resource cleanup.
"""
def __init__(self, max_workers: int = 5, name: Optional[str] = None):
"""
Initialize the background worker.
Args:
max_workers: Maximum number of concurrent worker threads
name: Optional name for the worker (used in logging)
"""
self.max_workers = max_workers
self.name = name or self.__class__.__name__
self._executor: Optional[ThreadPoolExecutor] = None
self._futures: WeakSet[Future] = WeakSet()
self._tasks: List[WorkerTask] = []
self._shutdown_event = threading.Event()
self._lock = threading.RLock()
self._started = False
def start(self) -> None:
"""Start the background worker."""
with self._lock:
if self._started:
logger.warning(f"Worker {self.name} is already started")
return
self._executor = ThreadPoolExecutor(
max_workers=self.max_workers, thread_name_prefix=f"{self.name}-worker"
)
self._started = True
logger.debug(f"Started background worker: {self.name}")
def submit_task(self, task: WorkerTask) -> Future:
"""
Submit a task for background execution.
Args:
task: The task to execute
Returns:
Future representing the task execution
Raises:
RuntimeError: If the worker is not started or is shutting down
"""
with self._lock:
if not self._started or self._shutdown_event.is_set():
raise RuntimeError(f"Worker {self.name} is not available")
if self._executor is None:
raise RuntimeError(f"Worker {self.name} executor is not initialized")
self._tasks.append(task)
future = self._executor.submit(task.execute)
self._futures.add(future)
logger.debug(f"Submitted task to worker {self.name}")
return future
def submit_function(self, func: Callable[..., Any], *args, **kwargs) -> Future:
"""
Submit a function for background execution.
Args:
func: The function to execute
*args: Positional arguments for the function
**kwargs: Keyword arguments for the function
Returns:
Future representing the task execution
"""
task = WorkerTask(func, *args, **kwargs)
return self.submit_task(task)
def cancel_all_tasks(self) -> int:
"""
Cancel all pending and running tasks.
Returns:
Number of tasks that were successfully cancelled
"""
cancelled_count = 0
with self._lock:
# Cancel all tasks
for task in self._tasks:
if task.cancel():
cancelled_count += 1
# Cancel all futures
for future in list(self._futures):
if future.cancel():
cancelled_count += 1
logger.debug(f"Cancelled {cancelled_count} tasks in worker {self.name}")
return cancelled_count
def shutdown(self, wait: bool = True, timeout: Optional[float] = 30.0) -> None:
"""
Shutdown the background worker.
Args:
wait: Whether to wait for running tasks to complete
timeout: Maximum time to wait for shutdown (ignored if wait=False)
"""
with self._lock:
if not self._started:
return
self._shutdown_event.set()
self._started = False
if self._executor is None:
return
logger.debug(f"Shutting down worker {self.name}")
if not wait:
# Cancel all tasks and shutdown immediately
self.cancel_all_tasks()
self._executor.shutdown(wait=False, cancel_futures=True)
else:
# Wait for tasks to complete with timeout
try:
self._executor.shutdown(wait=True, timeout=timeout)
except TimeoutError:
logger.warning(
f"Worker {self.name} shutdown timed out, forcing cancellation"
)
self.cancel_all_tasks()
self._executor.shutdown(wait=False, cancel_futures=True)
self._executor = None
logger.debug(f"Worker {self.name} shutdown complete")
def is_running(self) -> bool:
"""Check if the worker is currently running."""
return self._started and not self._shutdown_event.is_set()
def get_active_task_count(self) -> int:
"""Get the number of active (non-completed) tasks."""
with self._lock:
return sum(1 for task in self._tasks if not task.completed())
@abstractmethod
def _on_task_completed(self, task: WorkerTask, future: Future) -> None:
"""
Hook called when a task completes (successfully or with error).
Args:
task: The completed task
future: The future representing the task execution
"""
pass
def __enter__(self):
"""Context manager entry."""
self.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Context manager exit with automatic cleanup."""
self.shutdown(wait=True)
class ManagedBackgroundWorker(BackgroundWorker):
"""
Concrete implementation of BackgroundWorker with task completion tracking.
This worker provides additional monitoring and logging of task completion.
"""
def __init__(self, max_workers: int = 5, name: Optional[str] = None):
super().__init__(max_workers, name)
self._completed_tasks: List[WorkerTask] = []
self._failed_tasks: List[WorkerTask] = []
def _on_task_completed(self, task: WorkerTask, future: Future) -> None:
"""Track completed tasks and log results."""
try:
if future.exception():
self._failed_tasks.append(task)
logger.error(f"Task failed in worker {self.name}: {future.exception()}")
else:
self._completed_tasks.append(task)
logger.debug(f"Task completed successfully in worker {self.name}")
except Exception as e:
logger.error(f"Error in task completion handler: {e}")
def get_completion_stats(self) -> Dict[str, int]:
"""Get statistics about task completion."""
with self._lock:
return {
"total_tasks": len(self._tasks),
"completed_tasks": len(self._completed_tasks),
"failed_tasks": len(self._failed_tasks),
"active_tasks": self.get_active_task_count(),
}
class ThreadManager:
"""
Manages multiple background workers and provides centralized control.
This class acts as a registry for all background workers in the application,
allowing for coordinated shutdown and monitoring.
"""
def __init__(self):
self._workers: Dict[str, BackgroundWorker] = {}
self._lock = threading.RLock()
def register_worker(self, name: str, worker: BackgroundWorker) -> None:
"""
Register a background worker.
Args:
name: Unique name for the worker
worker: The worker instance to register
"""
with self._lock:
if name in self._workers:
raise ValueError(f"Worker with name '{name}' already registered")
self._workers[name] = worker
logger.debug(f"Registered worker: {name}")
def get_worker(self, name: str) -> Optional[BackgroundWorker]:
"""Get a registered worker by name."""
with self._lock:
return self._workers.get(name)
def shutdown_worker(
self, name: str, wait: bool = True, timeout: Optional[float] = 30.0
) -> bool:
"""
Shutdown a specific worker.
Args:
name: Name of the worker to shutdown
wait: Whether to wait for completion
timeout: Shutdown timeout
Returns:
True if worker was found and shutdown, False otherwise
"""
with self._lock:
worker = self._workers.get(name)
if worker:
worker.shutdown(wait=wait, timeout=timeout)
del self._workers[name]
logger.debug(f"Shutdown worker: {name}")
return True
return False
def shutdown_all(self, wait: bool = True, timeout: Optional[float] = 30.0) -> None:
"""Shutdown all registered workers."""
with self._lock:
workers_to_shutdown = list(self._workers.items())
for name, worker in workers_to_shutdown:
try:
worker.shutdown(wait=wait, timeout=timeout)
logger.debug(f"Shutdown worker: {name}")
except Exception as e:
logger.error(f"Error shutting down worker {name}: {e}")
with self._lock:
self._workers.clear()
def get_all_workers(self) -> Dict[str, BackgroundWorker]:
"""Get a copy of all registered workers."""
with self._lock:
return self._workers.copy()
def get_status(self) -> Dict[str, Dict[str, Any]]:
"""Get status information for all workers."""
status = {}
with self._lock:
for name, worker in self._workers.items():
status[name] = {
"running": worker.is_running(),
"active_tasks": worker.get_active_task_count(),
}
if isinstance(worker, ManagedBackgroundWorker):
status[name].update(worker.get_completion_stats())
return status
# Global thread manager instance
thread_manager = ThreadManager()

View File

@@ -7,9 +7,11 @@ from typing import IO, Any, Union
logger = logging.getLogger(__name__)
class NO_DEFAULT:
pass
def sanitize_filename(s, restricted=False, is_id=NO_DEFAULT):
"""Sanitizes a string so it could be used as part of a filename.
@param restricted Use a stricter subset of allowed characters
@@ -19,58 +21,85 @@ def sanitize_filename(s, restricted=False, is_id=NO_DEFAULT):
import itertools
import unicodedata
import re
ACCENT_CHARS = dict(zip('ÂÃÄÀÁÅÆÇÈÉÊËÌÍÎÏÐÑÒÓÔÕÖŐØŒÙÚÛÜŰÝÞßàáâãäåæçèéêëìíîïðñòóôõöőøœùúûüűýþÿ',
itertools.chain('AAAAAA', ['AE'], 'CEEEEIIIIDNOOOOOOO', ['OE'], 'UUUUUY', ['TH', 'ss'],
'aaaaaa', ['ae'], 'ceeeeiiiionooooooo', ['oe'], 'uuuuuy', ['th'], 'y')))
if s == '':
return ''
ACCENT_CHARS = dict(
zip(
"ÂÃÄÀÁÅÆÇÈÉÊËÌÍÎÏÐÑÒÓÔÕÖŐØŒÙÚÛÜŰÝÞßàáâãäåæçèéêëìíîïðñòóôõöőøœùúûüűýþÿ",
itertools.chain(
"AAAAAA",
["AE"],
"CEEEEIIIIDNOOOOOOO",
["OE"],
"UUUUUY",
["TH", "ss"],
"aaaaaa",
["ae"],
"ceeeeiiiionooooooo",
["oe"],
"uuuuuy",
["th"],
"y",
),
)
)
if s == "":
return ""
def replace_insane(char):
if restricted and char in ACCENT_CHARS:
return ACCENT_CHARS[char]
elif not restricted and char == '\n':
return '\0 '
elif not restricted and char == "\n":
return "\0 "
elif is_id is NO_DEFAULT and not restricted and char in '"*:<>?|/\\':
# Replace with their full-width unicode counterparts
return {'/': '\u29F8', '\\': '\u29f9'}.get(char, chr(ord(char) + 0xfee0))
elif char == '?' or ord(char) < 32 or ord(char) == 127:
return ''
return {"/": "\u29f8", "\\": "\u29f9"}.get(char, chr(ord(char) + 0xFEE0))
elif char == "?" or ord(char) < 32 or ord(char) == 127:
return ""
elif char == '"':
return '' if restricted else '\''
elif char == ':':
return '\0_\0-' if restricted else '\0 \0-'
elif char in '\\/|*<>':
return '\0_'
if restricted and (char in '!&\'()[]{}$;`^,#' or char.isspace() or ord(char) > 127):
return '' if unicodedata.category(char)[0] in 'CM' else '\0_'
return "" if restricted else "'"
elif char == ":":
return "\0_\0-" if restricted else "\0 \0-"
elif char in "\\/|*<>":
return "\0_"
if restricted and (
char in "!&'()[]{}$;`^,#" or char.isspace() or ord(char) > 127
):
return "" if unicodedata.category(char)[0] in "CM" else "\0_"
return char
# Replace look-alike Unicode glyphs
if restricted and (is_id is NO_DEFAULT or not is_id):
s = unicodedata.normalize('NFKC', s)
s = re.sub(r'[0-9]+(?::[0-9]+)+', lambda m: m.group(0).replace(':', '_'), s) # Handle timestamps
result = ''.join(map(replace_insane, s))
s = unicodedata.normalize("NFKC", s)
s = re.sub(
r"[0-9]+(?::[0-9]+)+", lambda m: m.group(0).replace(":", "_"), s
) # Handle timestamps
result = "".join(map(replace_insane, s))
if is_id is NO_DEFAULT:
result = re.sub(r'(\0.)(?:(?=\1)..)+', r'\1', result) # Remove repeated substitute chars
STRIP_RE = r'(?:\0.|[ _-])*'
result = re.sub(f'^\0.{STRIP_RE}|{STRIP_RE}\0.$', '', result) # Remove substitute chars from start/end
result = result.replace('\0', '') or '_'
result = re.sub(
r"(\0.)(?:(?=\1)..)+", r"\1", result
) # Remove repeated substitute chars
STRIP_RE = r"(?:\0.|[ _-])*"
result = re.sub(
f"^\0.{STRIP_RE}|{STRIP_RE}\0.$", "", result
) # Remove substitute chars from start/end
result = result.replace("\0", "") or "_"
if not is_id:
while '__' in result:
result = result.replace('__', '_')
result = result.strip('_')
while "__" in result:
result = result.replace("__", "_")
result = result.strip("_")
# Common case of "Foreign band name - English song title"
if restricted and result.startswith('-_'):
if restricted and result.startswith("-_"):
result = result[2:]
if result.startswith('-'):
result = '_' + result[len('-'):]
result = result.lstrip('.')
if result.startswith("-"):
result = "_" + result[len("-") :]
result = result.lstrip(".")
if not result:
result = '_'
result = "_"
return result
def get_file_modification_time(filepath: Path) -> float:
"""
Returns the modification time of a file as a Unix timestamp.

View File

@@ -6,37 +6,37 @@ otherwise falls back to a pure Python implementation with the same API.
Usage:
Basic usage with the convenience functions:
>>> from fastanime.core.utils.fuzzy import fuzz
>>> fuzz.ratio("hello world", "hello")
62
>>> fuzz.partial_ratio("hello world", "hello")
>>> fuzz.partial_ratio("hello world", "hello")
100
Using the FuzzyMatcher class directly:
>>> from fastanime.core.utils.fuzzy import FuzzyMatcher
>>> matcher = FuzzyMatcher()
>>> matcher.backend
'thefuzz' # or 'pure_python' if thefuzz is not available
>>> matcher.token_sort_ratio("fuzzy wuzzy", "wuzzy fuzzy")
100
For drop-in replacement of thefuzz.fuzz:
>>> from fastanime.core.utils.fuzzy import ratio, partial_ratio
>>> ratio("test", "best")
75
"""
import logging
from typing import Any, Optional, Union
logger = logging.getLogger(__name__)
# Try to import thefuzz, fall back to pure Python implementation
try:
from thefuzz import fuzz as _fuzz_impl
THEFUZZ_AVAILABLE = True
logger.debug("Using thefuzz for fuzzy matching")
except ImportError:
@@ -48,29 +48,29 @@ except ImportError:
class _PurePythonFuzz:
"""
Pure Python implementation of fuzzy string matching algorithms.
This provides the same API as thefuzz.fuzz but with pure Python implementations
of the core algorithms.
"""
@staticmethod
def _levenshtein_distance(s1: str, s2: str) -> int:
"""
Calculate the Levenshtein distance between two strings.
Args:
s1: First string
s2: Second string
Returns:
The Levenshtein distance as an integer
"""
if len(s1) < len(s2):
return _PurePythonFuzz._levenshtein_distance(s2, s1)
if len(s2) == 0:
return len(s1)
previous_row = list(range(len(s2) + 1))
for i, c1 in enumerate(s1):
current_row = [i + 1]
@@ -81,55 +81,55 @@ class _PurePythonFuzz:
substitutions = previous_row[j] + (c1 != c2)
current_row.append(min(insertions, deletions, substitutions))
previous_row = current_row
return previous_row[-1]
@staticmethod
def _longest_common_subsequence(s1: str, s2: str) -> int:
"""
Calculate the length of the longest common subsequence.
Args:
s1: First string
s2: Second string
Returns:
Length of the longest common subsequence
"""
m, n = len(s1), len(s2)
dp = [[0] * (n + 1) for _ in range(m + 1)]
for i in range(1, m + 1):
for j in range(1, n + 1):
if s1[i - 1] == s2[j - 1]:
dp[i][j] = dp[i - 1][j - 1] + 1
else:
dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])
return dp[m][n]
@staticmethod
def _normalize_string(s: str) -> str:
"""
Normalize a string for comparison by converting to lowercase and stripping whitespace.
Args:
s: String to normalize
Returns:
Normalized string
"""
return s.lower().strip()
@staticmethod
def ratio(s1: str, s2: str) -> int:
"""
Calculate the similarity ratio between two strings using Levenshtein distance.
Args:
s1: First string
s2: Second string
Returns:
Similarity ratio as an integer from 0 to 100
"""
@@ -137,185 +137,185 @@ class _PurePythonFuzz:
return 100
if not s1 or not s2:
return 0
distance = _PurePythonFuzz._levenshtein_distance(s1, s2)
max_len = max(len(s1), len(s2))
if max_len == 0:
return 100
similarity = (max_len - distance) / max_len
return int(similarity * 100)
@staticmethod
def partial_ratio(s1: str, s2: str) -> int:
"""
Calculate the partial similarity ratio between two strings.
This finds the best matching substring and calculates the ratio for that.
Args:
s1: First string
s2: Second string
Returns:
Partial similarity ratio as an integer from 0 to 100
"""
if not s1 or not s2:
return 0
if len(s1) <= len(s2):
shorter, longer = s1, s2
else:
shorter, longer = s2, s1
best_ratio = 0
for i in range(len(longer) - len(shorter) + 1):
substring = longer[i:i + len(shorter)]
substring = longer[i : i + len(shorter)]
ratio = _PurePythonFuzz.ratio(shorter, substring)
best_ratio = max(best_ratio, ratio)
return best_ratio
@staticmethod
def token_sort_ratio(s1: str, s2: str) -> int:
"""
Calculate similarity after sorting tokens in both strings.
Args:
s1: First string
s2: Second string
Returns:
Token sort ratio as an integer from 0 to 100
"""
if not s1 or not s2:
return 0
# Normalize and split into tokens
tokens1 = sorted(_PurePythonFuzz._normalize_string(s1).split())
tokens2 = sorted(_PurePythonFuzz._normalize_string(s2).split())
# Rejoin sorted tokens
sorted_s1 = ' '.join(tokens1)
sorted_s2 = ' '.join(tokens2)
sorted_s1 = " ".join(tokens1)
sorted_s2 = " ".join(tokens2)
return _PurePythonFuzz.ratio(sorted_s1, sorted_s2)
@staticmethod
def token_set_ratio(s1: str, s2: str) -> int:
"""
Calculate similarity using set operations on tokens.
Args:
s1: First string
s2: Second string
Returns:
Token set ratio as an integer from 0 to 100
"""
if not s1 or not s2:
return 0
# Normalize and split into tokens
tokens1 = set(_PurePythonFuzz._normalize_string(s1).split())
tokens2 = set(_PurePythonFuzz._normalize_string(s2).split())
# Find intersection and differences
intersection = tokens1 & tokens2
diff1 = tokens1 - tokens2
diff2 = tokens2 - tokens1
# Create sorted strings for comparison
sorted_intersection = ' '.join(sorted(intersection))
sorted_diff1 = ' '.join(sorted(diff1))
sorted_diff2 = ' '.join(sorted(diff2))
sorted_intersection = " ".join(sorted(intersection))
sorted_diff1 = " ".join(sorted(diff1))
sorted_diff2 = " ".join(sorted(diff2))
# Combine strings for comparison
combined1 = f"{sorted_intersection} {sorted_diff1}".strip()
combined2 = f"{sorted_intersection} {sorted_diff2}".strip()
if not combined1 and not combined2:
return 100
if not combined1 or not combined2:
return 0
return _PurePythonFuzz.ratio(combined1, combined2)
@staticmethod
def partial_token_sort_ratio(s1: str, s2: str) -> int:
"""
Calculate partial similarity after sorting tokens.
Args:
s1: First string
s2: Second string
Returns:
Partial token sort ratio as an integer from 0 to 100
"""
if not s1 or not s2:
return 0
# Normalize and split into tokens
tokens1 = sorted(_PurePythonFuzz._normalize_string(s1).split())
tokens2 = sorted(_PurePythonFuzz._normalize_string(s2).split())
# Rejoin sorted tokens
sorted_s1 = ' '.join(tokens1)
sorted_s2 = ' '.join(tokens2)
sorted_s1 = " ".join(tokens1)
sorted_s2 = " ".join(tokens2)
return _PurePythonFuzz.partial_ratio(sorted_s1, sorted_s2)
@staticmethod
def partial_token_set_ratio(s1: str, s2: str) -> int:
"""
Calculate partial similarity using set operations on tokens.
Args:
s1: First string
s2: Second string
Returns:
Partial token set ratio as an integer from 0 to 100
"""
if not s1 or not s2:
return 0
# Normalize and split into tokens
tokens1 = set(_PurePythonFuzz._normalize_string(s1).split())
tokens2 = set(_PurePythonFuzz._normalize_string(s2).split())
# Find intersection and differences
intersection = tokens1 & tokens2
diff1 = tokens1 - tokens2
diff2 = tokens2 - tokens1
# Create sorted strings for comparison
sorted_intersection = ' '.join(sorted(intersection))
sorted_diff1 = ' '.join(sorted(diff1))
sorted_diff2 = ' '.join(sorted(diff2))
sorted_intersection = " ".join(sorted(intersection))
sorted_diff1 = " ".join(sorted(diff1))
sorted_diff2 = " ".join(sorted(diff2))
# Combine strings for comparison
combined1 = f"{sorted_intersection} {sorted_diff1}".strip()
combined2 = f"{sorted_intersection} {sorted_diff2}".strip()
if not combined1 and not combined2:
return 100
if not combined1 or not combined2:
return 0
return _PurePythonFuzz.partial_ratio(combined1, combined2)
class FuzzyMatcher:
"""
Fuzzy string matching class with the same API as thefuzz.fuzz.
This class automatically uses thefuzz if available, otherwise falls back
to a pure Python implementation.
"""
def __init__(self):
"""Initialize the fuzzy matcher with the appropriate backend."""
if THEFUZZ_AVAILABLE and _fuzz_impl is not None:
@@ -324,22 +324,22 @@ class FuzzyMatcher:
else:
self._impl = _PurePythonFuzz
self._backend = "pure_python"
logger.debug(f"FuzzyMatcher initialized with backend: {self._backend}")
@property
def backend(self) -> str:
"""Get the name of the backend being used."""
return self._backend
def ratio(self, s1: str, s2: str) -> int:
"""
Calculate the similarity ratio between two strings.
Args:
s1: First string
s2: Second string
Returns:
Similarity ratio as an integer from 0 to 100
"""
@@ -348,15 +348,15 @@ class FuzzyMatcher:
except Exception as e:
logger.warning(f"Error in ratio calculation: {e}")
return 0
def partial_ratio(self, s1: str, s2: str) -> int:
"""
Calculate the partial similarity ratio between two strings.
Args:
s1: First string
s2: Second string
Returns:
Partial similarity ratio as an integer from 0 to 100
"""
@@ -365,15 +365,15 @@ class FuzzyMatcher:
except Exception as e:
logger.warning(f"Error in partial_ratio calculation: {e}")
return 0
def token_sort_ratio(self, s1: str, s2: str) -> int:
"""
Calculate similarity after sorting tokens in both strings.
Args:
s1: First string
s2: Second string
Returns:
Token sort ratio as an integer from 0 to 100
"""
@@ -382,15 +382,15 @@ class FuzzyMatcher:
except Exception as e:
logger.warning(f"Error in token_sort_ratio calculation: {e}")
return 0
def token_set_ratio(self, s1: str, s2: str) -> int:
"""
Calculate similarity using set operations on tokens.
Args:
s1: First string
s2: Second string
Returns:
Token set ratio as an integer from 0 to 100
"""
@@ -399,15 +399,15 @@ class FuzzyMatcher:
except Exception as e:
logger.warning(f"Error in token_set_ratio calculation: {e}")
return 0
def partial_token_sort_ratio(self, s1: str, s2: str) -> int:
"""
Calculate partial similarity after sorting tokens.
Args:
s1: First string
s2: Second string
Returns:
Partial token sort ratio as an integer from 0 to 100
"""
@@ -416,15 +416,15 @@ class FuzzyMatcher:
except Exception as e:
logger.warning(f"Error in partial_token_sort_ratio calculation: {e}")
return 0
def partial_token_set_ratio(self, s1: str, s2: str) -> int:
"""
Calculate partial similarity using set operations on tokens.
Args:
s1: First string
s2: Second string
Returns:
Partial token set ratio as an integer from 0 to 100
"""
@@ -433,15 +433,15 @@ class FuzzyMatcher:
except Exception as e:
logger.warning(f"Error in partial_token_set_ratio calculation: {e}")
return 0
def best_ratio(self, s1: str, s2: str) -> int:
"""
Get the best ratio from all available methods.
Args:
s1: First string
s2: Second string
Returns:
Best similarity ratio as an integer from 0 to 100
"""
@@ -468,13 +468,13 @@ partial_token_sort_ratio = fuzz.partial_token_sort_ratio
partial_token_set_ratio = fuzz.partial_token_set_ratio
__all__ = [
'FuzzyMatcher',
'fuzz',
'ratio',
'partial_ratio',
'token_sort_ratio',
'token_set_ratio',
'partial_token_sort_ratio',
'partial_token_set_ratio',
'THEFUZZ_AVAILABLE',
"FuzzyMatcher",
"fuzz",
"ratio",
"partial_ratio",
"token_sort_ratio",
"token_set_ratio",
"partial_token_sort_ratio",
"partial_token_set_ratio",
"THEFUZZ_AVAILABLE",
]

View File

@@ -26,15 +26,15 @@ Example Usage:
... provider_title_to_media_api_title,
... media_api_title_to_provider_title
... )
# Convert provider title to media API title
>>> provider_title_to_media_api_title("1P", "allanime")
'one piece'
# Convert media API title to provider title
>>> media_api_title_to_provider_title("one piece", "allanime")
'1P'
# Check available providers
>>> get_available_providers()
['allanime', 'hianime', 'animepahe']
@@ -44,7 +44,6 @@ Author: FastAnime Contributors
import json
import logging
from pathlib import Path
from typing import Dict, Optional
from ..constants import ASSETS_DIR
@@ -58,23 +57,23 @@ _normalizer_cache: Optional[Dict[str, Dict[str, str]]] = None
def _load_normalizer_data() -> Dict[str, Dict[str, str]]:
"""
Load the normalizer.json file and cache it.
Returns:
Dictionary containing provider mappings from normalizer.json
Raises:
FileNotFoundError: If normalizer.json is not found
json.JSONDecodeError: If normalizer.json is malformed
"""
global _normalizer_cache
if _normalizer_cache is not None:
return _normalizer_cache
normalizer_path = ASSETS_DIR / "normalizer.json"
try:
with open(normalizer_path, 'r', encoding='utf-8') as f:
with open(normalizer_path, "r", encoding="utf-8") as f:
_normalizer_cache = json.load(f)
logger.debug("Loaded normalizer data from %s", normalizer_path)
# Type checker now knows _normalizer_cache is not None
@@ -88,23 +87,20 @@ def _load_normalizer_data() -> Dict[str, Dict[str, str]]:
raise
def provider_title_to_media_api_title(
provider_title: str,
provider_name: str
) -> str:
def provider_title_to_media_api_title(provider_title: str, provider_name: str) -> str:
"""
Convert a provider title to its equivalent media API title.
This function takes a title from a specific provider (e.g., "1P" from allanime)
and converts it to the standard media API title (e.g., "one piece").
Args:
provider_title: The title as it appears on the provider
provider_name: The name of the provider (e.g., "allanime", "hianime", "animepahe")
Returns:
The normalized media API title, or the original title if no mapping exists
Example:
>>> provider_title_to_media_api_title("1P", "allanime")
"one piece"
@@ -115,48 +111,47 @@ def provider_title_to_media_api_title(
"""
try:
normalizer_data = _load_normalizer_data()
# Check if the provider exists in the normalizer data
if provider_name not in normalizer_data:
logger.debug("Provider '%s' not found in normalizer data", provider_name)
return provider_title
provider_mappings = normalizer_data[provider_name]
# Return the mapped title if it exists, otherwise return the original
normalized_title = provider_mappings.get(provider_title, provider_title)
if normalized_title != provider_title:
logger.debug(
"Normalized provider title: '%s' -> '%s' (provider: %s)",
provider_title, normalized_title, provider_name
provider_title,
normalized_title,
provider_name,
)
return normalized_title
except (FileNotFoundError, json.JSONDecodeError) as e:
logger.warning("Failed to load normalizer data: %s", e)
return provider_title
def media_api_title_to_provider_title(
media_api_title: str,
provider_name: str
) -> str:
def media_api_title_to_provider_title(media_api_title: str, provider_name: str) -> str:
"""
Convert a media API title to its equivalent provider title.
This function takes a standard media API title and converts it to the title
used by a specific provider. This is the reverse operation of
used by a specific provider. This is the reverse operation of
provider_title_to_media_api_title().
Args:
media_api_title: The title as it appears in the media API (e.g., AniList)
provider_name: The name of the provider (e.g., "allanime", "hianime", "animepahe")
Returns:
The provider-specific title, or the original title if no mapping exists
Example:
>>> media_api_title_to_provider_title("one piece", "allanime")
"1P"
@@ -167,53 +162,53 @@ def media_api_title_to_provider_title(
"""
try:
normalizer_data = _load_normalizer_data()
# Check if the provider exists in the normalizer data
if provider_name not in normalizer_data:
logger.debug("Provider '%s' not found in normalizer data", provider_name)
return media_api_title
provider_mappings = normalizer_data[provider_name]
# Create a reverse mapping (media_api_title -> provider_title)
reverse_mappings = {v: k for k, v in provider_mappings.items()}
# Return the mapped title if it exists, otherwise return the original
provider_title = reverse_mappings.get(media_api_title, media_api_title)
if provider_title != media_api_title:
logger.debug(
"Converted media API title to provider title: '%s' -> '%s' (provider: %s)",
media_api_title, provider_title, provider_name
media_api_title,
provider_title,
provider_name,
)
return provider_title
except (FileNotFoundError, json.JSONDecodeError) as e:
logger.warning("Failed to load normalizer data: %s", e)
return media_api_title
def normalize_title(
title: str,
provider_name: str,
use_provider_mapping: bool = False
title: str, provider_name: str, use_provider_mapping: bool = False
) -> str:
"""
Normalize a title for search operations.
This convenience function determines the appropriate normalization direction
based on the use_provider_mapping parameter.
Args:
title: The title to normalize
provider_name: The name of the provider
use_provider_mapping: If True, convert media API title to provider title.
If False, convert provider title to media API title.
Returns:
The normalized title
Example:
>>> normalize_title_for_search("one piece", "allanime", use_provider_mapping=True)
"1P"
@@ -229,10 +224,10 @@ def normalize_title(
def get_available_providers() -> list[str]:
"""
Get a list of all available providers in the normalizer data.
Returns:
List of provider names that have mappings defined
Example:
>>> get_available_providers()
['allanime', 'hianime', 'animepahe']
@@ -248,7 +243,7 @@ def get_available_providers() -> list[str]:
def clear_cache() -> None:
"""
Clear the internal cache for normalizer data.
This is useful for testing or when the normalizer.json file has been updated
and you want to reload the data.
"""
@@ -260,13 +255,13 @@ def clear_cache() -> None:
def get_provider_mappings(provider_name: str) -> Dict[str, str]:
"""
Get all title mappings for a specific provider.
Args:
provider_name: The name of the provider
Returns:
Dictionary mapping provider titles to media API titles
Example:
>>> mappings = get_provider_mappings("allanime")
>>> print(mappings["1P"])
@@ -283,16 +278,16 @@ def get_provider_mappings(provider_name: str) -> Dict[str, str]:
def has_mapping(title: str, provider_name: str, reverse: bool = False) -> bool:
"""
Check if a mapping exists for the given title and provider.
Args:
title: The title to check
provider_name: The name of the provider
reverse: If True, check for media API -> provider mapping.
If False, check for provider -> media API mapping.
Returns:
True if a mapping exists, False otherwise
Example:
>>> has_mapping("1P", "allanime", reverse=False)
True
@@ -303,44 +298,42 @@ def has_mapping(title: str, provider_name: str, reverse: bool = False) -> bool:
"""
try:
normalizer_data = _load_normalizer_data()
if provider_name not in normalizer_data:
return False
provider_mappings = normalizer_data[provider_name]
if reverse:
# Check if title exists as a value (media API title)
return title in provider_mappings.values()
else:
# Check if title exists as a key (provider title)
return title in provider_mappings
except (FileNotFoundError, json.JSONDecodeError) as e:
logger.warning("Failed to load normalizer data: %s", e)
return False
def add_runtime_mapping(
provider_title: str,
media_api_title: str,
provider_name: str
provider_title: str, media_api_title: str, provider_name: str
) -> None:
"""
Add a new mapping at runtime (not persisted to file).
This is useful for adding mappings discovered during runtime that
are not present in the normalizer.json file.
Args:
provider_title: The provider-specific title
media_api_title: The media API title
provider_name: The name of the provider
Note:
This mapping is only stored in memory and will be lost when
the cache is cleared or the application restarts.
Example:
>>> add_runtime_mapping("Custom Title", "Normalized Title", "allanime")
>>> provider_title_to_media_api_title("Custom Title", "allanime")
@@ -348,18 +341,20 @@ def add_runtime_mapping(
"""
try:
normalizer_data = _load_normalizer_data()
# Initialize provider if it doesn't exist
if provider_name not in normalizer_data:
normalizer_data[provider_name] = {}
# Add the mapping
normalizer_data[provider_name][provider_title] = media_api_title
logger.info(
"Added runtime mapping: '%s' -> '%s' (provider: %s)",
provider_title, media_api_title, provider_name
provider_title,
media_api_title,
provider_name,
)
except (FileNotFoundError, json.JSONDecodeError) as e:
logger.warning("Failed to add runtime mapping: %s", e)
logger.warning("Failed to add runtime mapping: %s", e)

View File

@@ -260,6 +260,22 @@ class AniListApi(BaseApiClient):
)
return response.json() if response else None
def transform_raw_search_data(self, raw_data: Dict) -> Optional[MediaSearchResult]:
"""
Transform raw AniList API response data into a MediaSearchResult.
Args:
raw_data: Raw response data from the AniList GraphQL API
Returns:
MediaSearchResult object or None if transformation fails
"""
try:
return mapper.to_generic_search_result(raw_data)
except Exception as e:
logger.error(f"Failed to transform raw search data: {e}")
return None
if __name__ == "__main__":
from httpx import Client

View File

@@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
# Map the client name to its import path AND the config section it needs.
API_CLIENTS = {
"anilist": ("fastanime.libs.media_api.anilist.api.AniListApi", "anilist"),
"jikan": ("fastanime.libs.media_api.jikan.api.JikanApi", "jikan"), # For the future
"jikan": ("fastanime.libs.media_api.jikan.api.JikanApi", "jikan"), # For the future
}

View File

@@ -78,3 +78,16 @@ class BaseApiClient(abc.ABC):
self, params: MediaAiringScheduleParams
) -> Optional[Dict]:
pass
@abc.abstractmethod
def transform_raw_search_data(self, raw_data: Dict) -> Optional[MediaSearchResult]:
"""
Transform raw API response data into a MediaSearchResult.
Args:
raw_data: Raw response data from the API
Returns:
MediaSearchResult object or None if transformation fails
"""
pass

View File

@@ -116,7 +116,7 @@ class JikanApi(BaseApiClient):
raw_data = self._execute_request(endpoint)
if not raw_data or "data" not in raw_data:
return None
recommendations = []
for item in raw_data["data"]:
# Jikan recommendation structure has an 'entry' field with anime data
@@ -124,7 +124,7 @@ class JikanApi(BaseApiClient):
if entry:
media_item = mapper._to_generic_media_item(entry)
recommendations.append(media_item)
return recommendations
except Exception as e:
logger.error(f"Failed to fetch recommendations for media {params.id}: {e}")
@@ -137,7 +137,7 @@ class JikanApi(BaseApiClient):
raw_data = self._execute_request(endpoint)
if not raw_data:
return None
# Return the raw character data as Jikan provides it
return raw_data
except Exception as e:
@@ -153,7 +153,7 @@ class JikanApi(BaseApiClient):
raw_data = self._execute_request(endpoint)
if not raw_data or "data" not in raw_data:
return None
related_anime = []
for relation in raw_data["data"]:
entries = relation.get("entry", [])
@@ -164,9 +164,7 @@ class JikanApi(BaseApiClient):
id=entry["mal_id"],
id_mal=entry["mal_id"],
title=MediaTitle(
english=entry["name"],
romaji=entry["name"],
native=None
english=entry["name"], romaji=entry["name"], native=None
),
cover_image=MediaImage(large=""),
description=None,
@@ -176,7 +174,7 @@ class JikanApi(BaseApiClient):
user_status=None,
)
related_anime.append(media_item)
return related_anime
except Exception as e:
logger.error(f"Failed to fetch related anime for media {params.id}: {e}")
@@ -186,5 +184,7 @@ class JikanApi(BaseApiClient):
self, params: MediaAiringScheduleParams
) -> Optional[Dict]:
"""Jikan doesn't provide a direct airing schedule endpoint per anime."""
logger.warning("Jikan API does not support fetching airing schedules for individual anime.")
logger.warning(
"Jikan API does not support fetching airing schedules for individual anime."
)
return None

View File

@@ -4,6 +4,7 @@ HTML parsing utilities with optional lxml support.
This module provides comprehensive HTML parsing capabilities using either
Python's built-in html.parser or lxml for better performance when available.
"""
# TODO: Review and optimize the HTML parsing logic for better performance and flexibility.
# Consider adding more utility functions for common HTML manipulation tasks.
import logging
@@ -20,6 +21,7 @@ logger = logging.getLogger(__name__)
HAS_LXML = False
try:
from lxml import etree, html as lxml_html
HAS_LXML = True
logger.debug("lxml is available and will be used for HTML parsing")
except ImportError:
@@ -28,11 +30,11 @@ except ImportError:
class HTMLParserConfig:
"""Configuration for HTML parser selection."""
def __init__(self, use_lxml: Optional[bool] = None):
"""
Initialize parser configuration.
Args:
use_lxml: Force use of lxml (True), html.parser (False), or auto-detect (None)
"""
@@ -40,30 +42,32 @@ class HTMLParserConfig:
self.use_lxml = HAS_LXML
else:
self.use_lxml = use_lxml and HAS_LXML
if use_lxml and not HAS_LXML:
logger.warning("lxml requested but not available, falling back to html.parser")
logger.warning(
"lxml requested but not available, falling back to html.parser"
)
class HTMLParser:
"""
Comprehensive HTML parser with optional lxml support.
Provides a unified interface for HTML parsing operations regardless
of the underlying parser implementation.
"""
def __init__(self, config: Optional[HTMLParserConfig] = None):
"""Initialize the HTML parser with configuration."""
self.config = config or HTMLParserConfig()
def parse(self, html_content: str) -> Union[Any, 'ParsedHTML']:
def parse(self, html_content: str) -> Union[Any, "ParsedHTML"]:
"""
Parse HTML content and return a parsed tree.
Args:
html_content: Raw HTML string to parse
Returns:
Parsed HTML tree (lxml Element or custom ParsedHTML object)
"""
@@ -71,7 +75,7 @@ class HTMLParser:
return self._parse_with_lxml(html_content)
else:
return self._parse_with_builtin(html_content)
def _parse_with_lxml(self, html_content: str) -> Any:
"""Parse HTML using lxml."""
try:
@@ -80,8 +84,8 @@ class HTMLParser:
except Exception as e:
logger.warning(f"lxml parsing failed: {e}, falling back to html.parser")
return self._parse_with_builtin(html_content)
def _parse_with_builtin(self, html_content: str) -> 'ParsedHTML':
def _parse_with_builtin(self, html_content: str) -> "ParsedHTML":
"""Parse HTML using Python's built-in parser."""
parser = BuiltinHTMLParser()
parser.feed(html_content)
@@ -90,89 +94,89 @@ class HTMLParser:
class BuiltinHTMLParser(BaseHTMLParser):
"""Enhanced HTML parser using Python's built-in capabilities."""
def __init__(self):
super().__init__()
self.elements = []
self.current_element = None
self.element_stack = []
def handle_starttag(self, tag: str, attrs: List[Tuple[str, Optional[str]]]):
"""Handle opening tags."""
element = {
'tag': tag,
'attrs': dict(attrs),
'text': '',
'children': [],
'start_pos': self.getpos(),
"tag": tag,
"attrs": dict(attrs),
"text": "",
"children": [],
"start_pos": self.getpos(),
}
if self.element_stack:
self.element_stack[-1]['children'].append(element)
self.element_stack[-1]["children"].append(element)
else:
self.elements.append(element)
self.element_stack.append(element)
def handle_endtag(self, tag: str):
"""Handle closing tags."""
if self.element_stack and self.element_stack[-1]['tag'] == tag:
if self.element_stack and self.element_stack[-1]["tag"] == tag:
element = self.element_stack.pop()
element['end_pos'] = self.getpos()
element["end_pos"] = self.getpos()
def handle_data(self, data: str):
"""Handle text content."""
if self.element_stack:
self.element_stack[-1]['text'] += data
self.element_stack[-1]["text"] += data
class ParsedHTML:
"""Wrapper for parsed HTML using built-in parser."""
def __init__(self, elements: List[Dict], raw_html: str):
self.elements = elements
self.raw_html = raw_html
def find_by_id(self, element_id: str) -> Optional[Dict]:
"""Find element by ID."""
return self._find_recursive(self.elements, lambda el: el['attrs'].get('id') == element_id)
return self._find_recursive(
self.elements, lambda el: el["attrs"].get("id") == element_id
)
def find_by_class(self, class_name: str) -> List[Dict]:
"""Find elements by class name."""
results = []
self._find_all_recursive(
self.elements,
lambda el: class_name in el['attrs'].get('class', '').split(),
results
self.elements,
lambda el: class_name in el["attrs"].get("class", "").split(),
results,
)
return results
def find_by_tag(self, tag_name: str) -> List[Dict]:
"""Find elements by tag name."""
results = []
self._find_all_recursive(
self.elements,
lambda el: el['tag'].lower() == tag_name.lower(),
results
self.elements, lambda el: el["tag"].lower() == tag_name.lower(), results
)
return results
def _find_recursive(self, elements: List[Dict], condition) -> Optional[Dict]:
"""Recursively find first element matching condition."""
for element in elements:
if condition(element):
return element
result = self._find_recursive(element['children'], condition)
result = self._find_recursive(element["children"], condition)
if result:
return result
return None
def _find_all_recursive(self, elements: List[Dict], condition, results: List[Dict]):
"""Recursively find all elements matching condition."""
for element in elements:
if condition(element):
results.append(element)
self._find_all_recursive(element['children'], condition, results)
self._find_all_recursive(element["children"], condition, results)
# Global parser instance
@@ -182,62 +186,62 @@ _default_parser = HTMLParser()
def extract_attributes(html_element: str) -> Dict[str, str]:
"""
Extract attributes from an HTML element string.
Args:
html_element: HTML element as string (e.g., '<div class="test" id="main">')
Returns:
Dictionary of attribute name-value pairs
Examples:
>>> extract_attributes('<div class="test" id="main">')
{'class': 'test', 'id': 'main'}
"""
if not html_element:
return {}
# Use regex to extract attributes from HTML string
attr_pattern = r'(\w+)=(["\'])([^"\']*?)\2'
matches = re.findall(attr_pattern, html_element)
attributes = {}
for match in matches:
attr_name, _, attr_value = match
attributes[attr_name] = attr_value
# Handle attributes without quotes
unquoted_pattern = r'(\w+)=([^\s>]+)'
unquoted_pattern = r"(\w+)=([^\s>]+)"
unquoted_matches = re.findall(unquoted_pattern, html_element)
for attr_name, attr_value in unquoted_matches:
if attr_name not in attributes:
attributes[attr_name] = attr_value
return attributes
def get_element_by_id(element_id: str, html_content: str) -> Optional[str]:
"""
Get HTML element by ID.
Args:
element_id: The ID attribute value to search for
html_content: HTML content to search in
Returns:
HTML string of the element or None if not found
Examples:
>>> html = '<div id="test">Content</div>'
>>> get_element_by_id("test", html)
'<div id="test">Content</div>'
"""
parsed = _default_parser.parse(html_content)
if _default_parser.config.use_lxml and HAS_LXML:
try:
element = parsed.xpath(f'//*[@id="{element_id}"]')
if element:
return etree.tostring(element[0], encoding='unicode', method='html')
return etree.tostring(element[0], encoding="unicode", method="html")
except Exception as e:
logger.warning(f"lxml XPath search failed: {e}")
return None
@@ -245,28 +249,28 @@ def get_element_by_id(element_id: str, html_content: str) -> Optional[str]:
element = parsed.find_by_id(element_id)
if element:
return _element_to_html(element, html_content)
return None
def get_element_by_tag(tag_name: str, html_content: str) -> Optional[str]:
"""
Get first HTML element by tag name.
Args:
tag_name: The tag name to search for
html_content: HTML content to search in
Returns:
HTML string of the element or None if not found
"""
parsed = _default_parser.parse(html_content)
if _default_parser.config.use_lxml and HAS_LXML:
try:
elements = parsed.xpath(f'//{tag_name}')
elements = parsed.xpath(f"//{tag_name}")
if elements:
return etree.tostring(elements[0], encoding='unicode', method='html')
return etree.tostring(elements[0], encoding="unicode", method="html")
except Exception as e:
logger.warning(f"lxml XPath search failed: {e}")
return None
@@ -274,28 +278,28 @@ def get_element_by_tag(tag_name: str, html_content: str) -> Optional[str]:
elements = parsed.find_by_tag(tag_name)
if elements:
return _element_to_html(elements[0], html_content)
return None
def get_element_by_class(class_name: str, html_content: str) -> Optional[str]:
"""
Get first HTML element by class name.
Args:
class_name: The class name to search for
html_content: HTML content to search in
Returns:
HTML string of the element or None if not found
"""
parsed = _default_parser.parse(html_content)
if _default_parser.config.use_lxml and HAS_LXML:
try:
elements = parsed.xpath(f'//*[contains(@class, "{class_name}")]')
if elements:
return etree.tostring(elements[0], encoding='unicode', method='html')
return etree.tostring(elements[0], encoding="unicode", method="html")
except Exception as e:
logger.warning(f"lxml XPath search failed: {e}")
return None
@@ -303,109 +307,119 @@ def get_element_by_class(class_name: str, html_content: str) -> Optional[str]:
elements = parsed.find_by_class(class_name)
if elements:
return _element_to_html(elements[0], html_content)
return None
def get_elements_by_tag(tag_name: str, html_content: str) -> List[str]:
"""
Get all HTML elements by tag name.
Args:
tag_name: The tag name to search for
html_content: HTML content to search in
Returns:
List of HTML strings for matching elements
"""
parsed = _default_parser.parse(html_content)
results = []
if _default_parser.config.use_lxml and HAS_LXML:
try:
elements = parsed.xpath(f'//{tag_name}')
elements = parsed.xpath(f"//{tag_name}")
for element in elements:
results.append(etree.tostring(element, encoding='unicode', method='html'))
results.append(
etree.tostring(element, encoding="unicode", method="html")
)
except Exception as e:
logger.warning(f"lxml XPath search failed: {e}")
else:
elements = parsed.find_by_tag(tag_name)
for element in elements:
results.append(_element_to_html(element, html_content))
return results
def get_elements_by_class(class_name: str, html_content: str) -> List[str]:
"""
Get all HTML elements by class name.
Args:
class_name: The class name to search for
html_content: HTML content to search in
Returns:
List of HTML strings for matching elements
"""
parsed = _default_parser.parse(html_content)
results = []
if _default_parser.config.use_lxml and HAS_LXML:
try:
elements = parsed.xpath(f'//*[contains(@class, "{class_name}")]')
for element in elements:
results.append(etree.tostring(element, encoding='unicode', method='html'))
results.append(
etree.tostring(element, encoding="unicode", method="html")
)
except Exception as e:
logger.warning(f"lxml XPath search failed: {e}")
else:
elements = parsed.find_by_class(class_name)
for element in elements:
results.append(_element_to_html(element, html_content))
return results
def get_elements_html_by_class(class_name: str, html_content: str) -> List[str]:
"""
Get HTML strings of elements by class name.
This is an alias for get_elements_by_class for yt-dlp compatibility.
Args:
class_name: The class name to search for
html_content: HTML content to search in
Returns:
List of HTML strings for matching elements
"""
return get_elements_by_class(class_name, html_content)
def get_element_text_and_html_by_tag(tag_name: str, html_content: str) -> Tuple[Optional[str], Optional[str]]:
def get_element_text_and_html_by_tag(
tag_name: str, html_content: str
) -> Tuple[Optional[str], Optional[str]]:
"""
Get both text content and HTML of first element by tag name.
Args:
tag_name: The tag name to search for
html_content: HTML content to search in
Returns:
Tuple of (text_content, html_string) or (None, None) if not found
Examples:
>>> html = '<script>alert("test");</script>'
>>> get_element_text_and_html_by_tag("script", html)
('alert("test");', '<script>alert("test");</script>')
"""
parsed = _default_parser.parse(html_content)
if _default_parser.config.use_lxml and HAS_LXML:
try:
elements = parsed.xpath(f'//{tag_name}')
elements = parsed.xpath(f"//{tag_name}")
if elements:
element = elements[0]
text = element.text_content() if hasattr(element, 'text_content') else (element.text or '')
html_str = etree.tostring(element, encoding='unicode', method='html')
text = (
element.text_content()
if hasattr(element, "text_content")
else (element.text or "")
)
html_str = etree.tostring(element, encoding="unicode", method="html")
return text, html_str
except Exception as e:
logger.warning(f"lxml XPath search failed: {e}")
@@ -417,61 +431,63 @@ def get_element_text_and_html_by_tag(tag_name: str, html_content: str) -> Tuple[
text = _extract_text_content(element)
html_str = _element_to_html(element, html_content)
return text, html_str
return None, None
def _element_to_html(element: Dict, original_html: str) -> str:
"""
Convert parsed element back to HTML string.
This is a simplified implementation that reconstructs HTML from parsed data.
For production use, consider using lxml for better accuracy.
"""
if not element:
return ""
# Build opening tag
tag = element['tag']
attrs = element.get('attrs', {})
attr_str = ' '.join(f'{k}="{v}"' for k, v in attrs.items() if v is not None)
tag = element["tag"]
attrs = element.get("attrs", {})
attr_str = " ".join(f'{k}="{v}"' for k, v in attrs.items() if v is not None)
if attr_str:
opening_tag = f"<{tag} {attr_str}>"
else:
opening_tag = f"<{tag}>"
# Add text content
text = element.get('text', '')
text = element.get("text", "")
# Add children
children_html = ""
for child in element.get('children', []):
for child in element.get("children", []):
children_html += _element_to_html(child, original_html)
# Build closing tag
closing_tag = f"</{tag}>"
return f"{opening_tag}{text}{children_html}{closing_tag}"
def _extract_text_content(element: Dict) -> str:
"""Extract all text content from element and its children."""
text = element.get('text', '')
for child in element.get('children', []):
text = element.get("text", "")
for child in element.get("children", []):
text += _extract_text_content(child)
return text
def configure_parser(use_lxml: Optional[bool] = None) -> None:
"""
Configure the global HTML parser.
Args:
use_lxml: Force use of lxml (True), html.parser (False), or auto-detect (None)
"""
global _default_parser
_default_parser = HTMLParser(HTMLParserConfig(use_lxml))
logger.info(f"HTML parser configured: {'lxml' if _default_parser.config.use_lxml else 'html.parser'}")
logger.info(
f"HTML parser configured: {'lxml' if _default_parser.config.use_lxml else 'html.parser'}"
)

View File

@@ -12,150 +12,147 @@ from typing import List, Optional
class UserAgentGenerator:
"""
Generator for realistic user agent strings.
Provides a variety of common user agents from different browsers
and operating systems to help avoid detection.
"""
# Common user agents for different browsers and OS combinations
USER_AGENTS = [
# Chrome on Windows
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36",
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/121.0.0.0 Safari/537.36",
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
# Chrome on macOS
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36",
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/121.0.0.0 Safari/537.36",
# Chrome on Linux
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36",
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/121.0.0.0 Safari/537.36",
# Firefox on Windows
"Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:123.0) Gecko/20100101 Firefox/123.0",
"Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:122.0) Gecko/20100101 Firefox/122.0",
# Firefox on macOS
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:123.0) Gecko/20100101 Firefox/123.0",
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:122.0) Gecko/20100101 Firefox/122.0",
# Firefox on Linux
"Mozilla/5.0 (X11; Linux x86_64; rv:123.0) Gecko/20100101 Firefox/123.0",
"Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:123.0) Gecko/20100101 Firefox/123.0",
# Safari on macOS
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.3 Safari/605.1.15",
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.2 Safari/605.1.15",
# Edge on Windows
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36 Edg/122.0.0.0",
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/121.0.0.0 Safari/537.36 Edg/121.0.0.0",
# Mobile Chrome (Android)
"Mozilla/5.0 (Linux; Android 14; SM-G998B) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Mobile Safari/537.36",
"Mozilla/5.0 (Linux; Android 13; Pixel 7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Mobile Safari/537.36",
# Mobile Safari (iOS)
"Mozilla/5.0 (iPhone; CPU iPhone OS 17_3 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.3 Mobile/15E148 Safari/604.1",
"Mozilla/5.0 (iPad; CPU OS 17_3 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.3 Mobile/15E148 Safari/604.1",
]
# Browser-specific user agents for when you need a specific browser
CHROME_USER_AGENTS = [ua for ua in USER_AGENTS if "Chrome" in ua and "Edg" not in ua]
CHROME_USER_AGENTS = [
ua for ua in USER_AGENTS if "Chrome" in ua and "Edg" not in ua
]
FIREFOX_USER_AGENTS = [ua for ua in USER_AGENTS if "Firefox" in ua]
SAFARI_USER_AGENTS = [ua for ua in USER_AGENTS if "Safari" in ua and "Chrome" not in ua]
SAFARI_USER_AGENTS = [
ua for ua in USER_AGENTS if "Safari" in ua and "Chrome" not in ua
]
EDGE_USER_AGENTS = [ua for ua in USER_AGENTS if "Edg" in ua]
# Platform-specific user agents
WINDOWS_USER_AGENTS = [ua for ua in USER_AGENTS if "Windows NT" in ua]
MACOS_USER_AGENTS = [ua for ua in USER_AGENTS if "Macintosh" in ua]
LINUX_USER_AGENTS = [ua for ua in USER_AGENTS if "Linux" in ua and "Android" not in ua]
LINUX_USER_AGENTS = [
ua for ua in USER_AGENTS if "Linux" in ua and "Android" not in ua
]
MOBILE_USER_AGENTS = [ua for ua in USER_AGENTS if "Mobile" in ua or "Android" in ua]
def __init__(self, seed: Optional[int] = None):
"""
Initialize the user agent generator.
Args:
seed: Random seed for reproducible results (optional)
"""
if seed is not None:
random.seed(seed)
def random(self) -> str:
"""
Get a random user agent string.
Returns:
Random user agent string
"""
return random.choice(self.USER_AGENTS)
def random_browser(self, browser: str) -> str:
"""
Get a random user agent for a specific browser.
Args:
browser: Browser name ('chrome', 'firefox', 'safari', 'edge')
Returns:
Random user agent string for the specified browser
Raises:
ValueError: If browser is not supported
"""
browser = browser.lower()
if browser == 'chrome':
if browser == "chrome":
return random.choice(self.CHROME_USER_AGENTS)
elif browser == 'firefox':
elif browser == "firefox":
return random.choice(self.FIREFOX_USER_AGENTS)
elif browser == 'safari':
elif browser == "safari":
return random.choice(self.SAFARI_USER_AGENTS)
elif browser == 'edge':
elif browser == "edge":
return random.choice(self.EDGE_USER_AGENTS)
else:
raise ValueError(f"Unsupported browser: {browser}")
def random_platform(self, platform: str) -> str:
"""
Get a random user agent for a specific platform.
Args:
platform: Platform name ('windows', 'macos', 'linux', 'mobile')
Returns:
Random user agent string for the specified platform
Raises:
ValueError: If platform is not supported
"""
platform = platform.lower()
if platform == 'windows':
if platform == "windows":
return random.choice(self.WINDOWS_USER_AGENTS)
elif platform in ('macos', 'mac'):
elif platform in ("macos", "mac"):
return random.choice(self.MACOS_USER_AGENTS)
elif platform == 'linux':
elif platform == "linux":
return random.choice(self.LINUX_USER_AGENTS)
elif platform == 'mobile':
elif platform == "mobile":
return random.choice(self.MOBILE_USER_AGENTS)
else:
raise ValueError(f"Unsupported platform: {platform}")
def add_user_agent(self, user_agent: str) -> None:
"""
Add a custom user agent to the list.
Args:
user_agent: Custom user agent string to add
"""
if user_agent not in self.USER_AGENTS:
self.USER_AGENTS.append(user_agent)
def get_all(self) -> List[str]:
"""
Get all available user agent strings.
Returns:
List of all user agent strings
"""
@@ -169,10 +166,10 @@ _default_generator = UserAgentGenerator()
def random_user_agent() -> str:
"""
Get a random user agent string using the default generator.
Returns:
Random user agent string
Examples:
>>> ua = random_user_agent()
>>> "Mozilla" in ua
@@ -184,10 +181,10 @@ def random_user_agent() -> str:
def random_user_agent_browser(browser: str) -> str:
"""
Get a random user agent for a specific browser.
Args:
browser: Browser name ('chrome', 'firefox', 'safari', 'edge')
Returns:
Random user agent string for the specified browser
"""
@@ -197,10 +194,10 @@ def random_user_agent_browser(browser: str) -> str:
def random_user_agent_platform(platform: str) -> str:
"""
Get a random user agent for a specific platform.
Args:
platform: Platform name ('windows', 'macos', 'linux', 'mobile')
Returns:
Random user agent string for the specified platform
"""
@@ -210,7 +207,7 @@ def random_user_agent_platform(platform: str) -> str:
def set_user_agent_seed(seed: int) -> None:
"""
Set the random seed for user agent generation.
Args:
seed: Random seed value
"""
@@ -221,7 +218,7 @@ def set_user_agent_seed(seed: int) -> None:
def add_custom_user_agent(user_agent: str) -> None:
"""
Add a custom user agent to the default generator.
Args:
user_agent: Custom user agent string to add
"""
@@ -231,7 +228,7 @@ def add_custom_user_agent(user_agent: str) -> None:
def get_all_user_agents() -> List[str]:
"""
Get all available user agent strings from the default generator.
Returns:
List of all user agent strings
"""

View File

@@ -6,21 +6,21 @@ that was previously sourced from yt-dlp.
"""
import string
from typing import Union,Optional
from typing import Optional
def encode_base_n(num: int, n: int, table: Optional[str] = None) -> str:
"""
Encode a number in base-n representation.
Args:
num: The number to encode
n: The base to use for encoding
table: Custom character table (optional)
Returns:
String representation of the number in base-n
Examples:
>>> encode_base_n(255, 16)
'ff'
@@ -30,39 +30,39 @@ def encode_base_n(num: int, n: int, table: Optional[str] = None) -> str:
if table is None:
# Default table: 0-9, a-z
table = string.digits + string.ascii_lowercase
if not 2 <= n <= len(table):
raise ValueError(f"Base must be between 2 and {len(table)}")
if num == 0:
return table[0]
result = []
is_negative = num < 0
num = abs(num)
while num > 0:
result.append(table[num % n])
num //= n
if is_negative:
result.append('-')
return ''.join(reversed(result))
result.append("-")
return "".join(reversed(result))
def decode_base_n(encoded: str, n: int, table: Optional[str] = None) -> int:
"""
Decode a base-n encoded string back to an integer.
Args:
encoded: The base-n encoded string
n: The base used for encoding
table: Custom character table (optional)
Returns:
The decoded integer
Examples:
>>> decode_base_n('ff', 16)
255
@@ -71,129 +71,135 @@ def decode_base_n(encoded: str, n: int, table: Optional[str] = None) -> int:
"""
if table is None:
table = string.digits + string.ascii_lowercase
if not 2 <= n <= len(table):
raise ValueError(f"Base must be between 2 and {len(table)}")
if not encoded:
return 0
is_negative = encoded.startswith('-')
is_negative = encoded.startswith("-")
if is_negative:
encoded = encoded[1:]
result = 0
for i, char in enumerate(reversed(encoded.lower())):
if char not in table:
raise ValueError(f"Invalid character '{char}' for base {n}")
digit_value = table.index(char)
if digit_value >= n:
raise ValueError(f"Invalid digit '{char}' for base {n}")
result += digit_value * (n ** i)
result += digit_value * (n**i)
return -result if is_negative else result
def url_encode(text: str, safe: str = '') -> str:
def url_encode(text: str, safe: str = "") -> str:
"""
URL encode a string.
Args:
text: Text to encode
safe: Characters that should not be encoded
Returns:
URL encoded string
"""
import urllib.parse
return urllib.parse.quote(text, safe=safe)
def url_decode(text: str) -> str:
"""
URL decode a string.
Args:
text: URL encoded text to decode
Returns:
Decoded string
"""
import urllib.parse
return urllib.parse.unquote(text)
def html_unescape(text: str) -> str:
"""
Unescape HTML entities in text.
Args:
text: Text containing HTML entities
Returns:
Text with HTML entities unescaped
Examples:
>>> html_unescape('&quot;Hello&quot; &amp; &lt;World&gt;')
'"Hello" & <World>'
"""
import html
return html.unescape(text)
def strip_tags(html_content: str) -> str:
"""
Remove all HTML tags from content, leaving only text.
Args:
html_content: HTML content with tags
Returns:
Plain text with tags removed
Examples:
>>> strip_tags('<p>Hello <b>world</b>!</p>')
'Hello world!'
"""
import re
return re.sub(r'<[^>]+>', '', html_content)
return re.sub(r"<[^>]+>", "", html_content)
def normalize_whitespace(text: str) -> str:
"""
Normalize whitespace in text by collapsing multiple spaces and removing leading/trailing whitespace.
Args:
text: Text to normalize
Returns:
Text with normalized whitespace
Examples:
>>> normalize_whitespace(' Hello world \\n\\t ')
'Hello world'
"""
import re
return re.sub(r'\s+', ' ', text.strip())
return re.sub(r"\s+", " ", text.strip())
def extract_domain(url: str) -> str:
"""
Extract domain from a URL.
Args:
url: Full URL
Returns:
Domain portion of the URL
Examples:
>>> extract_domain('https://example.com/path?query=1')
'example.com'
"""
import urllib.parse
parsed = urllib.parse.urlparse(url)
return parsed.netloc
@@ -201,38 +207,40 @@ def extract_domain(url: str) -> str:
def join_url(base: str, path: str) -> str:
"""
Join a base URL with a path.
Args:
base: Base URL
path: Path to join
Returns:
Combined URL
Examples:
>>> join_url('https://example.com', '/api/data')
'https://example.com/api/data'
"""
import urllib.parse
return urllib.parse.urljoin(base, path)
def parse_query_string(query: str) -> dict:
"""
Parse a query string into a dictionary.
Args:
query: Query string (with or without leading '?')
Returns:
Dictionary of query parameters
Examples:
>>> parse_query_string('?name=John&age=30')
{'name': ['John'], 'age': ['30']}
"""
import urllib.parse
if query.startswith('?'):
if query.startswith("?"):
query = query[1:]
return urllib.parse.parse_qs(query)
@@ -240,19 +248,19 @@ def parse_query_string(query: str) -> dict:
def build_query_string(params: dict) -> str:
"""
Build a query string from a dictionary of parameters.
Args:
params: Dictionary of parameters
Returns:
URL-encoded query string
Examples:
>>> build_query_string({'name': 'John', 'age': 30})
'name=John&age=30'
"""
import urllib.parse
# Handle both single values and lists
normalized_params = {}
for key, value in params.items():
@@ -260,5 +268,5 @@ def build_query_string(params: dict) -> str:
normalized_params[key] = value
else:
normalized_params[key] = [str(value)]
return urllib.parse.urlencode(normalized_params, doseq=True)

View File

@@ -55,7 +55,7 @@ class BaseSelector(ABC):
# Default implementation: single selection in a loop
selected = []
remaining_choices = choices.copy()
while remaining_choices:
choice = self.choose(
f"{prompt} (Select multiple, empty to finish)",
@@ -63,16 +63,18 @@ class BaseSelector(ABC):
preview=preview,
header=header,
)
if not choice or choice == "[DONE] Finish selection":
break
selected.append(choice)
remaining_choices.remove(choice)
if not self.confirm(f"Selected: {', '.join(selected)}. Continue selecting?", default=True):
if not self.confirm(
f"Selected: {', '.join(selected)}. Continue selecting?", default=True
):
break
return selected
@abstractmethod
@@ -102,3 +104,28 @@ class BaseSelector(ABC):
The string entered by the user.
"""
pass
def search(
self,
prompt: str,
search_command: str,
*,
preview: Optional[str] = None,
header: Optional[str] = None,
) -> str | None:
"""
Provides dynamic search functionality that reloads results based on user input.
Args:
prompt: The message to display to the user.
search_command: The command to execute for searching/reloading results.
preview: An optional command or string for a preview window.
header: An optional header to display above the choices.
Returns:
The string of the chosen item.
Raises:
NotImplementedError: If the selector doesn't support dynamic search.
"""
raise NotImplementedError("Dynamic search is not supported by this selector")

View File

@@ -77,9 +77,11 @@ class FzfSelector(BaseSelector):
)
if result.returncode != 0:
return []
# Split the output by newlines and filter out empty strings
selections = [line.strip() for line in result.stdout.strip().split('\n') if line.strip()]
selections = [
line.strip() for line in result.stdout.strip().split("\n") if line.strip()
]
return selections
def confirm(self, prompt, *, default=False):
@@ -112,3 +114,30 @@ class FzfSelector(BaseSelector):
# The output contains the selection (if any) and the query on the last line
lines = result.stdout.strip().splitlines()
return lines[-1] if lines else (default or "")
def search(self, prompt, search_command, *, preview=None, header=None):
"""Enhanced search using fzf's --reload flag for dynamic search."""
commands = [
self.executable,
"--prompt",
f"{prompt.title()}: ",
"--header",
header or self.header,
"--header-first",
"--bind",
f"change:reload({search_command})",
"--ansi",
]
if preview:
commands.extend(["--preview", preview])
result = subprocess.run(
commands,
input="",
stdout=subprocess.PIPE,
text=True,
)
if result.returncode != 0:
return None
return result.stdout.strip()

View File

@@ -7,109 +7,109 @@ from fastanime.cli.utils.parser import parse_episode_range
class TestParseEpisodeRange:
"""Test cases for the parse_episode_range function."""
@pytest.fixture
def episodes(self):
"""Sample episode list for testing."""
return ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"]
def test_no_range_returns_all_episodes(self, episodes):
"""Test that None or empty range returns all episodes."""
result = list(parse_episode_range(None, episodes))
assert result == episodes
def test_colon_only_returns_all_episodes(self, episodes):
"""Test that ':' returns all episodes."""
result = list(parse_episode_range(":", episodes))
assert result == episodes
def test_start_end_range(self, episodes):
"""Test start:end range format."""
result = list(parse_episode_range("2:5", episodes))
assert result == ["3", "4", "5"]
def test_start_only_range(self, episodes):
"""Test start: range format."""
result = list(parse_episode_range("5:", episodes))
assert result == ["6", "7", "8", "9", "10"]
def test_end_only_range(self, episodes):
"""Test :end range format."""
result = list(parse_episode_range(":3", episodes))
assert result == ["1", "2", "3"]
def test_start_end_step_range(self, episodes):
"""Test start:end:step range format."""
result = list(parse_episode_range("2:8:2", episodes))
assert result == ["3", "5", "7"]
def test_single_number_range(self, episodes):
"""Test single number format (start from index)."""
result = list(parse_episode_range("5", episodes))
assert result == ["6", "7", "8", "9", "10"]
def test_empty_start_end_in_three_part_range_raises_error(self, episodes):
"""Test that empty parts in start:end:step format raise error."""
with pytest.raises(ValueError, match="When using 3 parts"):
list(parse_episode_range(":5:2", episodes))
with pytest.raises(ValueError, match="When using 3 parts"):
list(parse_episode_range("2::2", episodes))
with pytest.raises(ValueError, match="When using 3 parts"):
list(parse_episode_range("2:5:", episodes))
def test_invalid_integer_raises_error(self, episodes):
"""Test that invalid integers raise ValueError."""
with pytest.raises(ValueError, match="Must be a valid integer"):
list(parse_episode_range("abc", episodes))
with pytest.raises(ValueError, match="Start and end must be valid integers"):
list(parse_episode_range("2:abc", episodes))
with pytest.raises(ValueError, match="All parts must be valid integers"):
list(parse_episode_range("2:5:abc", episodes))
def test_zero_step_raises_error(self, episodes):
"""Test that zero step raises ValueError."""
with pytest.raises(ValueError, match="Step value must be positive"):
list(parse_episode_range("2:5:0", episodes))
def test_negative_step_raises_error(self, episodes):
"""Test that negative step raises ValueError."""
with pytest.raises(ValueError, match="Step value must be positive"):
list(parse_episode_range("2:5:-1", episodes))
def test_too_many_colons_raises_error(self, episodes):
"""Test that too many colons raise ValueError."""
with pytest.raises(ValueError, match="Too many colon separators"):
list(parse_episode_range("2:5:7:9", episodes))
def test_edge_case_empty_list(self):
"""Test behavior with empty episode list."""
result = list(parse_episode_range(":", []))
assert result == []
def test_edge_case_single_episode(self):
"""Test behavior with single episode."""
episodes = ["1"]
result = list(parse_episode_range(":", episodes))
assert result == ["1"]
result = list(parse_episode_range("0:1", episodes))
assert result == ["1"]
def test_numerical_sorting(self):
"""Test that episodes are sorted numerically, not lexicographically."""
episodes = ["10", "2", "1", "11", "3"]
result = list(parse_episode_range(":", episodes))
assert result == ["1", "2", "3", "10", "11"]
def test_index_out_of_bounds_behavior(self, episodes):
"""Test behavior when indices exceed available episodes."""
# Python slicing handles out-of-bounds gracefully
result = list(parse_episode_range("15:", episodes))
assert result == [] # No episodes beyond index 15
result = list(parse_episode_range(":20", episodes))
assert result == episodes # All episodes (slice stops at end)

View File

@@ -5,14 +5,13 @@ Tests for the TorrentDownloader class.
import tempfile
import unittest
from pathlib import Path
from unittest.mock import Mock, patch, MagicMock
from unittest.mock import Mock, patch
import pytest
from fastanime.core.downloader.torrents import (
TorrentDownloader,
TorrentDownloadError,
LIBTORRENT_AVAILABLE
LIBTORRENT_AVAILABLE,
)
from fastanime.core.exceptions import DependencyNotFoundError
@@ -27,12 +26,13 @@ class TestTorrentDownloader(unittest.TestCase):
download_path=self.temp_dir,
max_upload_rate=100,
max_download_rate=200,
max_connections=50
max_connections=50,
)
def tearDown(self):
"""Clean up test fixtures."""
import shutil
shutil.rmtree(self.temp_dir, ignore_errors=True)
def test_init(self):
@@ -47,48 +47,50 @@ class TestTorrentDownloader(unittest.TestCase):
"""Test that download directory is created if it doesn't exist."""
non_existent_dir = self.temp_dir / "new_dir"
self.assertFalse(non_existent_dir.exists())
downloader = TorrentDownloader(download_path=non_existent_dir)
self.assertTrue(non_existent_dir.exists())
@patch('fastanime.core.downloader.torrents.shutil.which')
@patch("fastanime.core.downloader.torrents.shutil.which")
def test_download_with_webtorrent_cli_not_available(self, mock_which):
"""Test webtorrent CLI fallback when not available."""
mock_which.return_value = None
with self.assertRaises(DependencyNotFoundError) as context:
self.downloader.download_with_webtorrent_cli("magnet:test")
self.assertIn("webtorrent CLI is not available", str(context.exception))
@patch('fastanime.core.downloader.torrents.subprocess.run')
@patch('fastanime.core.downloader.torrents.shutil.which')
@patch("fastanime.core.downloader.torrents.subprocess.run")
@patch("fastanime.core.downloader.torrents.shutil.which")
def test_download_with_webtorrent_cli_success(self, mock_which, mock_run):
"""Test successful webtorrent CLI download."""
mock_which.return_value = "/usr/bin/webtorrent"
mock_result = Mock()
mock_result.stdout = f"Downloaded test-file to {self.temp_dir}/test-file"
mock_run.return_value = mock_result
# Create a dummy file to simulate download
test_file = self.temp_dir / "test-file"
test_file.touch()
result = self.downloader.download_with_webtorrent_cli("magnet:test")
mock_run.assert_called_once()
self.assertEqual(result, test_file)
@patch('fastanime.core.downloader.torrents.subprocess.run')
@patch('fastanime.core.downloader.torrents.shutil.which')
@patch("fastanime.core.downloader.torrents.subprocess.run")
@patch("fastanime.core.downloader.torrents.shutil.which")
def test_download_with_webtorrent_cli_failure(self, mock_which, mock_run):
"""Test webtorrent CLI download failure."""
mock_which.return_value = "/usr/bin/webtorrent"
mock_run.side_effect = subprocess.CalledProcessError(1, "webtorrent", stderr="Error")
mock_run.side_effect = subprocess.CalledProcessError(
1, "webtorrent", stderr="Error"
)
with self.assertRaises(TorrentDownloadError) as context:
self.downloader.download_with_webtorrent_cli("magnet:test")
self.assertIn("webtorrent CLI failed", str(context.exception))
@unittest.skipUnless(LIBTORRENT_AVAILABLE, "libtorrent not available")
@@ -103,61 +105,60 @@ class TestTorrentDownloader(unittest.TestCase):
with self.assertRaises(DependencyNotFoundError):
self.downloader._setup_libtorrent_session()
@patch('fastanime.core.downloader.torrents.LIBTORRENT_AVAILABLE', False)
@patch("fastanime.core.downloader.torrents.LIBTORRENT_AVAILABLE", False)
def test_download_with_libtorrent_not_available(self):
"""Test libtorrent download when not available."""
with self.assertRaises(DependencyNotFoundError) as context:
self.downloader.download_with_libtorrent("magnet:test")
self.assertIn("libtorrent is not available", str(context.exception))
def test_progress_callback(self):
"""Test progress callback functionality."""
callback_mock = Mock()
downloader = TorrentDownloader(
download_path=self.temp_dir,
progress_callback=callback_mock
download_path=self.temp_dir, progress_callback=callback_mock
)
# The callback should be stored
self.assertEqual(downloader.progress_callback, callback_mock)
@patch.object(TorrentDownloader, 'download_with_webtorrent_cli')
@patch.object(TorrentDownloader, 'download_with_libtorrent')
@patch.object(TorrentDownloader, "download_with_webtorrent_cli")
@patch.object(TorrentDownloader, "download_with_libtorrent")
def test_download_prefers_libtorrent(self, mock_libtorrent, mock_webtorrent):
"""Test that download method prefers libtorrent by default."""
mock_libtorrent.return_value = self.temp_dir / "test"
with patch('fastanime.core.downloader.torrents.LIBTORRENT_AVAILABLE', True):
with patch("fastanime.core.downloader.torrents.LIBTORRENT_AVAILABLE", True):
result = self.downloader.download("magnet:test", prefer_libtorrent=True)
mock_libtorrent.assert_called_once()
mock_webtorrent.assert_not_called()
@patch.object(TorrentDownloader, 'download_with_webtorrent_cli')
@patch.object(TorrentDownloader, 'download_with_libtorrent')
@patch.object(TorrentDownloader, "download_with_webtorrent_cli")
@patch.object(TorrentDownloader, "download_with_libtorrent")
def test_download_fallback_to_webtorrent(self, mock_libtorrent, mock_webtorrent):
"""Test fallback to webtorrent when libtorrent fails."""
mock_libtorrent.side_effect = DependencyNotFoundError("libtorrent not found")
mock_webtorrent.return_value = self.temp_dir / "test"
with patch('fastanime.core.downloader.torrents.LIBTORRENT_AVAILABLE', True):
with patch("fastanime.core.downloader.torrents.LIBTORRENT_AVAILABLE", True):
result = self.downloader.download("magnet:test")
mock_libtorrent.assert_called_once()
mock_webtorrent.assert_called_once()
self.assertEqual(result, self.temp_dir / "test")
@patch.object(TorrentDownloader, 'download_with_webtorrent_cli')
@patch.object(TorrentDownloader, 'download_with_libtorrent')
@patch.object(TorrentDownloader, "download_with_webtorrent_cli")
@patch.object(TorrentDownloader, "download_with_libtorrent")
def test_download_all_methods_fail(self, mock_libtorrent, mock_webtorrent):
"""Test when all download methods fail."""
mock_libtorrent.side_effect = DependencyNotFoundError("libtorrent not found")
mock_webtorrent.side_effect = DependencyNotFoundError("webtorrent not found")
with self.assertRaises(TorrentDownloadError) as context:
self.downloader.download("magnet:test")
self.assertIn("All torrent download methods failed", str(context.exception))
def test_magnet_link_detection(self):
@@ -165,12 +166,12 @@ class TestTorrentDownloader(unittest.TestCase):
magnet_link = "magnet:?xt=urn:btih:test"
http_link = "http://example.com/test.torrent"
file_path = "/path/to/test.torrent"
# These would be tested in integration tests with actual libtorrent
# Here we just verify the method exists and handles different input types
self.assertTrue(magnet_link.startswith('magnet:'))
self.assertTrue(http_link.startswith(('http://', 'https://')))
self.assertFalse(file_path.startswith(('magnet:', 'http://', 'https://')))
self.assertTrue(magnet_link.startswith("magnet:"))
self.assertTrue(http_link.startswith(("http://", "https://")))
self.assertFalse(file_path.startswith(("magnet:", "http://", "https://")))
class TestLegacyFunction(unittest.TestCase):
@@ -183,18 +184,21 @@ class TestLegacyFunction(unittest.TestCase):
def tearDown(self):
"""Clean up test fixtures."""
import shutil
shutil.rmtree(self.temp_dir, ignore_errors=True)
@patch.object(TorrentDownloader, 'download_with_webtorrent_cli')
@patch.object(TorrentDownloader, "download_with_webtorrent_cli")
def test_legacy_function(self, mock_download):
"""Test the legacy download_torrent_with_webtorrent_cli function."""
from fastanime.core.downloader.torrents import download_torrent_with_webtorrent_cli
from fastanime.core.downloader.torrents import (
download_torrent_with_webtorrent_cli,
)
test_path = self.temp_dir / "test.mkv"
mock_download.return_value = test_path
result = download_torrent_with_webtorrent_cli(test_path, "magnet:test")
mock_download.assert_called_once_with("magnet:test")
self.assertEqual(result, test_path)
@@ -202,4 +206,5 @@ class TestLegacyFunction(unittest.TestCase):
if __name__ == "__main__":
# Add subprocess import for the test
import subprocess
unittest.main()