1339 lines
55 KiB
Python
1339 lines
55 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
OTA管理器模块
|
||
提供OTA升级的状态管理和主要功能封装
|
||
"""
|
||
import binascii
|
||
import hashlib
|
||
import re
|
||
import threading
|
||
import os
|
||
import json
|
||
import shutil
|
||
from urllib.parse import urlparse, unquote
|
||
|
||
import requests
|
||
from maix import time
|
||
import config
|
||
from hardware import hardware_manager
|
||
from network import network_manager
|
||
from logger_manager import logger_manager
|
||
from power import get_bus_voltage, voltage_to_percent
|
||
|
||
|
||
# 延迟导入避免循环依赖
|
||
# from network import network_manager
|
||
|
||
|
||
class OTAManager:
|
||
"""OTA升级管理器(单例)"""
|
||
_instance = None
|
||
|
||
def __new__(cls):
    """Return the single shared OTAManager instance, creating it on first use.

    ``__init__`` still runs on every ``OTAManager()`` call, so a freshly created
    instance is flagged as not yet initialised and ``__init__`` uses that flag
    to set up state exactly once.
    """
    instance = cls._instance
    if instance is None:
        instance = super(OTAManager, cls).__new__(cls)
        cls._instance = instance
        instance._initialized = False
    return instance
|
||
|
||
def __init__(self):
    """Set up the private OTA state once; later constructions are no-ops."""
    if not self._initialized:
        # Whether an OTA worker thread has been marked as started.
        self._update_thread_started = False
        # Counter of in-flight OTA operations (see _begin_ota / _end_ota).
        self._ota_in_progress = 0
        self._ota_url = None
        self._ota_mode = None
        # Guards the state fields above.
        self._lock = threading.Lock()
        # Set per download by download_file_via_4g.
        self._is_https = False
        self._initialized = True
|
||
|
||
# ==================== 状态访问(只读属性)====================
|
||
|
||
@property
def logger(self):
    """The application logger, looked up on logger_manager at each access."""
    manager = logger_manager
    return manager.logger
|
||
|
||
@property
def update_thread_started(self):
    """Whether the OTA thread is currently marked as started (read without the lock)."""
    started = self._update_thread_started
    return started
|
||
|
||
@property
def ota_in_progress(self):
    """True while at least one OTA operation is in flight (counter read under the lock)."""
    with self._lock:
        active = self._ota_in_progress
    return active > 0
|
||
|
||
@property
def ota_url(self):
    """The most recently recorded OTA URL, or None (read without the lock)."""
    url = self._ota_url
    return url
|
||
|
||
@property
def ota_mode(self):
    """The most recently recorded OTA mode (e.g. "4g"), or None (read without the lock)."""
    mode = self._ota_mode
    return mode
|
||
|
||
# ==================== 内部状态管理方法 ====================
|
||
|
||
def _start_update_thread(self):
    """Atomically mark the OTA thread as started.

    Returns:
        True if this call set the flag, False if it was already set.
    """
    with self._lock:
        already_running = self._update_thread_started
        if not already_running:
            self._update_thread_started = True
        return not already_running
|
||
|
||
def _stop_update_thread(self):
    """Clear the "OTA thread started" flag while holding the state lock."""
    lock = self._lock
    with lock:
        self._update_thread_started = False
|
||
|
||
def _begin_ota(self, url=None, mode=None):
    """Increment the in-flight OTA counter and optionally record URL/mode.

    Falsy ``url``/``mode`` values (None, "") leave the stored values unchanged.
    """
    with self._lock:
        self._ota_in_progress = self._ota_in_progress + 1
        for attr, value in (("_ota_url", url), ("_ota_mode", mode)):
            if value:
                setattr(self, attr, value)
|
||
|
||
def _end_ota(self):
    """Decrement the in-flight OTA counter, clamping it at zero."""
    with self._lock:
        remaining = self._ota_in_progress - 1
        self._ota_in_progress = remaining if remaining > 0 else 0
|
||
|
||
def _set_ota_url(self, url):
    """Store *url* as the current OTA URL (stored as given, including None)."""
    lock = self._lock
    lock.acquire()
    try:
        self._ota_url = url
    finally:
        lock.release()
|
||
|
||
def _set_ota_mode(self, mode):
    """Store *mode* as the current OTA mode (stored as given, including None)."""
    lock = self._lock
    lock.acquire()
    try:
        self._ota_mode = mode
    finally:
        lock.release()
|
||
|
||
|
||
def is_archive_file(self, filename):
    """Decide from the file extension whether *filename* is a ZIP archive.

    Convention: uploaded code is either a ``.zip`` archive or a plain ``.py``
    file. A path that does not exist is never reported as an archive and
    produces no log line.

    Returns:
        ``(True, 'zip')`` for an existing file ending in ``.zip`` (any case),
        otherwise ``(False, None)``.
    """
    if os.path.exists(filename):
        if filename.lower().endswith('.zip'):
            self.logger.info(f"[EXTRACT] 检测到ZIP文件(扩展名: .zip)")
            return True, 'zip'
        extension = os.path.splitext(filename)[1]
        self.logger.info(f"[EXTRACT] 不是ZIP格式(扩展名: {extension or '无'})")
    return False, None
|
||
|
||
def extract_zip_archive(self, archive_path, extract_to_dir=None, target_file=None):
    """
    Extract a ZIP file with the system ``unzip`` command.

    The command is run without a shell, using an argument list. Archive names
    are derived from the OTA URL, and a quote in such a name must not be able
    to inject shell commands.

    Args:
        archive_path: path of the ZIP file.
        extract_to_dir: destination directory; defaults to the archive's own
            directory (or ``/tmp`` when the path has no directory part).
        target_file: member to extract (e.g. ``'main.py'``). If extracting
            only that member fails, the whole archive is extracted instead.
            None extracts every member.

    Returns:
        ``(True, extract_to_dir)`` on success, ``(False, None)`` on failure.
        A missing ``unzip`` binary also counts as failure.
    """
    import subprocess

    if extract_to_dir is None:
        extract_to_dir = os.path.dirname(archive_path) or '/tmp'

    self.logger.info(f"[EXTRACT] 开始解压ZIP文件: {archive_path}")

    def _run_unzip(*members):
        # -q quiet, -o overwrite without prompting; stderr is merged into stdout
        # (the old "2>&1") and only logged when unzip fails.
        cmd = ["unzip", "-q", "-o", archive_path, *members, "-d", extract_to_dir]
        proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        if proc.returncode != 0:
            output = (proc.stdout or b"").decode("utf-8", "replace").strip()
            if output:
                self.logger.warning(f"[EXTRACT] unzip 输出: {output}")
        return proc.returncode

    try:
        os.makedirs(extract_to_dir, exist_ok=True)

        result = _run_unzip(target_file) if target_file else _run_unzip()
        if result != 0:
            if not target_file:
                # The full extraction already failed; repeating it would not help.
                self.logger.error(f"[EXTRACT] 解压失败,退出码: {result}")
                return False, None

            self.logger.warning(f"[EXTRACT] 直接解压目标文件失败,尝试解压所有文件...")
            result_all = _run_unzip()
            if result_all != 0:
                self.logger.error(f"[EXTRACT] 解压失败,退出码: {result_all}")
                return False, None

        return True, extract_to_dir

    except Exception as e:
        self.logger.error(f"[EXTRACT] 解压过程出错: {e}")
        return False, None
|
||
|
||
def apply_ota_and_reboot(self, ota_url=None, downloaded_file=None):
    """
    Apply a downloaded OTA package, then reboot the device.

    Steps:
      1. If the package is AEAD-encrypted (".enc" suffix or b"AROTAE1" magic),
         decrypt it into a temporary plaintext ZIP.
      2. Back up the code files currently in config.APP_DIR (best effort).
      3. Extract a ZIP package into a freshly emptied temporary directory.
         Any other file is installed as a single file ("main_tmp.py" is
         installed as "main.py").
      4. Copy the update files into config.APP_DIR. If a copy fails, files
         already written are restored from the backup, or removed if they were
         new, so APP_DIR is not left half-updated.
      5. Write ota_pending.json, notify the server, clean up, and run "reboot".

    Temporary artifacts (extraction directory, decrypted ZIP) are removed on
    every failure path. The original download is only deleted after a
    successful apply.

    Args:
        ota_url: OTA URL; stored via _set_ota_url and written to ota_pending.json.
        downloaded_file: path of the downloaded package; defaults to
            config.LOCAL_FILENAME.

    Returns:
        True after the reboot command has been issued, False on any failure.
    """
    if ota_url:
        self._set_ota_url(ota_url)

    if downloaded_file is None:
        downloaded_file = config.LOCAL_FILENAME

    ota_pending = f"{config.APP_DIR}/ota_pending.json"

    self.logger.info(f"[OTA] 准备应用OTA更新,下载文件: {downloaded_file}")

    downloaded_file_original = downloaded_file
    decrypted_tmp_zip = None
    extract_dir = None
    rebooting = False
    try:
        if not os.path.exists(downloaded_file):
            self.logger.error(f"[OTA] 错误:{downloaded_file} 不存在")
            return False

        # Step 1: decrypt (no-op for plaintext packages).
        ok, decrypted_tmp_zip = self._decrypt_ota_package_if_needed(downloaded_file)
        if not ok:
            return False
        if decrypted_tmp_zip:
            downloaded_file = decrypted_tmp_zip

        # Step 2: back up the current application files.
        backup_dir = self._backup_app_dir()

        # Step 3: work out which files to install.
        is_archive, archive_type = self.is_archive_file(downloaded_file)
        if is_archive and archive_type == 'zip':
            extract_dir, files_to_copy = self._extract_ota_zip(downloaded_file)
            if files_to_copy is None:
                return False
        else:
            filename = os.path.basename(downloaded_file)
            # A download named main_tmp.py replaces main.py; any other name is kept as-is.
            target_rel_path = "main.py" if filename == "main_tmp.py" else filename
            files_to_copy = [(downloaded_file, target_rel_path)]
            self.logger.info(f"[OTA] 单个文件更新: {downloaded_file} -> {target_rel_path}")

        # Step 4: install (rolls back on failure).
        copied_files = self._copy_ota_files(files_to_copy, backup_dir)
        if copied_files is None:
            return False
        if copied_files:
            self.logger.info(f"[OTA] 成功复制 {len(copied_files)} 个文件到应用目录")

        # Make sure the new files reach the disk before the reboot.
        try:
            os.sync()
        except Exception:
            pass
        time.sleep_ms(500)

        # Step 5: record the pending update.
        try:
            pending_obj = {
                "ts": 0,  # MaixPy's time module has no time() function, so 0 is used
                "url": ota_url or "",
                "downloaded_file": downloaded_file,
                "was_archive": is_archive,
                "archive_type": archive_type if is_archive else None,
                "backup_dir": backup_dir,
                "updated_files": copied_files,
                "restart_count": 0,
                "max_restarts": 3,
            }
            with open(ota_pending, "w", encoding="utf-8") as f:
                json.dump(pending_obj, f)
            try:
                os.sync()
            except Exception:
                pass
        except Exception as e:
            self.logger.error(f"[OTA] 写入 ota_pending 失败: {e}")

        # Notify the server (imported lazily to avoid a circular import).
        from network import safe_enqueue
        safe_enqueue({"result": "ota_applied_rebooting", "files": copied_files}, 2)
        time.sleep_ms(1000)

        if extract_dir:
            try:
                if os.path.exists(extract_dir):
                    shutil.rmtree(extract_dir, ignore_errors=True)
                    self.logger.info(f"[OTA] 已清理临时解压目录: {extract_dir}")
            except Exception as e:
                self.logger.warning(f"[OTA] 清理临时目录失败(可忽略): {e}")

        # Remove the original package and the decrypted plaintext ZIP (if any).
        self._cleanup_ota_downloads([downloaded_file_original, decrypted_tmp_zip])

        self.logger.info("[OTA] 准备重启设备...")
        rebooting = True
        os.system("reboot")

        return True

    except Exception as e:
        self.logger.error(f"[OTA] apply_ota_and_reboot 异常: {e}")
        import traceback
        self.logger.error(traceback.format_exc())
        return False

    finally:
        if not rebooting:
            self._discard_ota_temp_files(extract_dir, decrypted_tmp_zip)


def _decrypt_ota_package_if_needed(self, package_path):
    """
    Decrypt an AEAD-encrypted OTA package into a plaintext ZIP next to it.

    A package is treated as encrypted when its name ends with ".enc" or its
    first bytes are the b"AROTAE1" magic. Output naming is:
    "x.zip.enc" becomes "x.zip", "x.enc" becomes "x.zip", and any other
    encrypted name "x" becomes "x.zip".

    Returns:
        (True, None) when the package is not encrypted,
        (True, zip_path) after a successful decryption,
        (False, None) on failure (any partial output is removed).
    """
    magic = b"AROTAE1"  # must match packager/C++ side
    tmp_zip = None

    def _remove_if_exists(path):
        try:
            if path and os.path.exists(path):
                os.remove(path)
        except OSError:
            pass

    try:
        is_enc_ext = package_path.lower().endswith(".enc")
        try:
            with open(package_path, "rb") as f:
                is_enc_magic = f.read(len(magic)) == magic
        except Exception:
            is_enc_magic = False

        if not (is_enc_ext or is_enc_magic):
            return True, None

        if is_enc_ext:
            tmp_zip = package_path[:-4]  # strip ".enc"
            if not tmp_zip.lower().endswith(".zip"):
                tmp_zip = tmp_zip + ".zip"
        else:
            tmp_zip = package_path + ".zip"

        # Remove a stale output file from an earlier attempt.
        _remove_if_exists(tmp_zip)

        self.logger.info(f"[OTA] 检测到加密包,开始解密: {package_path} -> {tmp_zip}")
        try:
            core = getattr(network_manager, "_netcore", None)
            if core and hasattr(core, "decrypt_ota_file"):
                ok = bool(core.decrypt_ota_file(package_path, tmp_zip))
            else:
                import archery_netcore as _netcore
                ok = bool(_netcore.decrypt_ota_file(package_path, tmp_zip))
        except Exception as e:
            self.logger.error(f"[OTA] 解密异常: {e}")
            ok = False

        if not ok or not os.path.exists(tmp_zip):
            self.logger.error("[OTA] 解密失败,终止更新")
            _remove_if_exists(tmp_zip)
            return False, None

        self.logger.info(f"[OTA] 解密成功,后续使用明文ZIP: {tmp_zip}")
        return True, tmp_zip

    except Exception as e:
        self.logger.error(f"[OTA] 解密流程异常: {e}")
        _remove_if_exists(tmp_zip)
        return False, None


def _backup_app_dir(self):
    """
    Copy the code files in config.APP_DIR into a numbered backup directory.

    This is best effort: every error is logged and none aborts the OTA. Only
    the config.MAX_BACKUPS newest existing backups are kept, and pruning
    happens before the new one is created.

    Returns:
        The backup directory path, which may be only partially populated after
        an error. Returns None if no name could be chosen.
    """
    backup_base = config.BACKUP_BASE
    backup_dir = None
    try:
        os.makedirs(backup_base, exist_ok=True)

        # Monotonic counter persisted in BACKUP_BASE/.counter -> backup_0001, backup_0002, ...
        counter_file = os.path.join(backup_base, ".counter")
        try:
            if os.path.exists(counter_file):
                with open(counter_file, 'r') as f:
                    counter = int(f.read().strip()) + 1
            else:
                counter = 1

            with open(counter_file, 'w') as f:
                f.write(str(counter))

            backup_dir = os.path.join(backup_base, f"backup_{counter:04d}")
            self.logger.info(f"[OTA] 使用备份目录: {backup_dir} (第{counter}次OTA)")
        except Exception as e:
            self.logger.error(f"[OTA] 生成备份目录名失败: {e},使用默认目录")
            backup_dir = os.path.join(backup_base, "backup_0000")

        # Prune old backups, keeping the MAX_BACKUPS highest-numbered ones.
        try:
            backup_dirs = []
            for item in os.listdir(backup_base):
                item_path = os.path.join(backup_base, item)
                if item.startswith("backup_") and os.path.isdir(item_path):
                    try:
                        backup_dirs.append((item, int(item[len("backup_"):]), item_path))
                    except ValueError:
                        pass

            backup_dirs.sort(key=lambda x: x[1], reverse=True)
            for item, _num, item_path in backup_dirs[config.MAX_BACKUPS:]:
                try:
                    shutil.rmtree(item_path, ignore_errors=True)
                    self.logger.info(f"[OTA] 已删除旧备份: {item}")
                except Exception as e:
                    self.logger.warning(f"[OTA] 删除旧备份失败: {e}")
        except Exception as e:
            self.logger.warning(f"[OTA] 清理旧备份时出错: {e}")

        os.makedirs(backup_dir, exist_ok=True)

        # Substring match against file and directory names.
        exclude_patterns = ['.pyc', '__pycache__', '.log', 'backups', 'ota_extract', '.bak', 'download']
        backed_up_files = []

        if os.path.exists(config.APP_DIR):
            for root, dirs, files in os.walk(config.APP_DIR):
                # Prune excluded directories in place so os.walk skips them entirely.
                dirs[:] = [d for d in dirs if not any(ex in d for ex in exclude_patterns)]

                for name in files:
                    if any(ex in name for ex in exclude_patterns):
                        continue

                    source_path = os.path.join(root, name)
                    rel_path = os.path.relpath(source_path, config.APP_DIR)
                    backup_path = os.path.join(backup_dir, rel_path)

                    backup_parent = os.path.dirname(backup_path)
                    if backup_parent != backup_dir:
                        os.makedirs(backup_parent, exist_ok=True)

                    try:
                        shutil.copy2(source_path, backup_path)
                        backed_up_files.append(rel_path)
                    except Exception as e:
                        self.logger.error(f"[OTA] 备份 {rel_path} 失败: {e}")

        if backed_up_files:
            self.logger.info(f"[OTA] 总共备份了 {len(backed_up_files)} 个文件到 {backup_dir}")
        else:
            self.logger.warning(f"[OTA] 没有备份任何文件")

    except Exception as e:
        self.logger.error(f"[OTA] 备份过程出错: {e}")

    return backup_dir


def _extract_ota_zip(self, zip_path):
    """
    Validate *zip_path* and extract it into a freshly emptied temporary directory.

    The directory is emptied first. Leftovers from an earlier failed OTA would
    otherwise be copied into APP_DIR together with this package's files.

    Returns:
        (extract_dir, files_to_copy). ``files_to_copy`` is a list of
        ``(source_path, rel_path)`` tuples, or None on failure (already
        logged). ``extract_dir`` is the directory used, or None if none was
        created, so the caller can remove it.
    """
    try:
        with open(zip_path, "rb") as f:
            zip_header = f.read(4)
        if zip_header[:2] != b'PK':
            self.logger.error(f"[OTA] ZIP文件头验证失败: {zip_header.hex()}")
            return None, None
        file_size = os.path.getsize(zip_path)
        self.logger.info(f"[OTA] ZIP文件验证通过: 大小={file_size} bytes, 头={zip_header.hex()}")
    except Exception as e:
        self.logger.error(f"[OTA] ZIP文件验证异常: {e}")
        return None, None

    self.logger.info(f"[OTA] 检测到ZIP压缩包,开始解压...")

    extract_dir = None
    for candidate in ("/tmp/ota_extract", f"{config.APP_DIR}/ota_extract"):
        shutil.rmtree(candidate, ignore_errors=True)
        try:
            os.makedirs(candidate, exist_ok=True)
            extract_dir = candidate
            break
        except OSError:
            continue
    if extract_dir is None:
        self.logger.error("[OTA] 无法创建临时解压目录")
        return None, None

    success, extracted_dir = self.extract_zip_archive(
        zip_path,
        extract_to_dir=extract_dir,
        target_file=None
    )
    if not (success and extracted_dir and os.path.exists(extracted_dir)):
        self.logger.error(f"[OTA] 解压失败")
        return extract_dir, None

    files_to_copy = []
    for root, _dirs, names in os.walk(extracted_dir):
        for name in names:
            source_path = os.path.join(root, name)
            files_to_copy.append((source_path, os.path.relpath(source_path, extracted_dir)))

    if not files_to_copy:
        self.logger.error(f"[OTA] 解压成功但未找到任何文件")
        return extract_dir, None

    self.logger.info(f"[OTA] 解压成功,共 {len(files_to_copy)} 个文件")
    return extract_dir, files_to_copy


def _copy_ota_files(self, files_to_copy, backup_dir):
    """
    Copy ``(source_path, rel_path)`` pairs into config.APP_DIR.

    If any copy fails, the files written so far are rolled back via
    _rollback_ota_files and None is returned.

    Returns:
        The list of installed relative paths, or None on failure.
    """
    copied_files = []
    written = []  # (rel_path, dest_path, existed_before) for every file we touched
    for source_path, rel_path in files_to_copy:
        dest_path = os.path.join(config.APP_DIR, rel_path)

        # Source and destination are the same file: nothing to copy.
        if os.path.abspath(source_path) == os.path.abspath(dest_path):
            self.logger.warning(f"[OTA] 源文件和目标文件相同,跳过复制: {rel_path} (文件已在正确位置)")
            copied_files.append(rel_path)
            continue

        dest_dir = os.path.dirname(dest_path)
        if dest_dir:
            try:
                os.makedirs(dest_dir, exist_ok=True)
            except Exception:
                pass

        existed_before = os.path.exists(dest_path)
        written.append((rel_path, dest_path, existed_before))
        try:
            shutil.copy2(source_path, dest_path)
        except Exception as e:
            self.logger.error(f"[OTA] 复制 {rel_path} 失败: {e}")
            self._rollback_ota_files(written, backup_dir)
            return None

        copied_files.append(rel_path)
        self.logger.info(f"[OTA] 已复制: {rel_path}")

    return copied_files


def _rollback_ota_files(self, written, backup_dir):
    """Best-effort undo of a failed install.

    New files are removed and overwritten files are restored from *backup_dir*.
    """
    for rel_path, dest_path, existed_before in reversed(written):
        try:
            backup_path = os.path.join(backup_dir, rel_path) if backup_dir else None
            if not existed_before:
                if os.path.exists(dest_path):
                    os.remove(dest_path)
            elif backup_path and os.path.isfile(backup_path):
                shutil.copy2(backup_path, dest_path)
            else:
                self.logger.warning(f"[OTA] 无备份,无法回滚: {rel_path}")
                continue
            self.logger.info(f"[OTA] 已回滚: {rel_path}")
        except Exception as e:
            self.logger.warning(f"[OTA] 回滚 {rel_path} 失败: {e}")


def _cleanup_ota_downloads(self, paths):
    """Delete the given downloaded/decrypted files (None entries ignored) and, if
    it is now empty, their /tmp/download/<timestamp> directory. Best effort."""
    files_to_remove = [p for p in dict.fromkeys(paths) if p]
    try:
        for fp in files_to_remove:
            try:
                if os.path.exists(fp):
                    os.remove(fp)
                    self.logger.info(f"[OTA] 已删除下载文件: {fp}")
            except Exception as e:
                self.logger.warning(f"[OTA] 删除下载文件失败(可忽略): {e}")

        if files_to_remove:
            download_dir = os.path.dirname(files_to_remove[0])
            if download_dir.startswith("/tmp/download/") and os.path.exists(download_dir):
                try:
                    if not os.listdir(download_dir):
                        os.rmdir(download_dir)
                        self.logger.info(f"[OTA] 已删除空时间戳目录: {download_dir}")
                except Exception as e:
                    self.logger.debug(f"[OTA] 删除时间戳目录失败(可忽略): {e}")
    except Exception as e:
        self.logger.warning(f"[OTA] 清理下载文件时出错(可忽略): {e}")


def _discard_ota_temp_files(self, extract_dir, decrypted_tmp_zip):
    """Remove the extraction directory and decrypted plaintext ZIP after a failed OTA."""
    if extract_dir:
        shutil.rmtree(extract_dir, ignore_errors=True)
    if decrypted_tmp_zip:
        try:
            if os.path.exists(decrypted_tmp_zip):
                os.remove(decrypted_tmp_zip)
        except OSError as e:
            self.logger.warning(f"[OTA] 清理临时文件失败(可忽略): {e}")
|
||
|
||
def get_download_timestamp_dir(self):
    """
    Return (and create) a download directory of the form
    ``/tmp/download/YYYYMMDD_HHMMSS``.

    A full timestamp rather than a date keeps separate downloads apart,
    including across midnight. The timestamp comes from the system ``date``
    command, falling back to ``datetime.now()`` and finally to
    ``00000000_000000``. If the timestamped directory cannot be created,
    ``/tmp/download`` itself is returned.

    Returns:
        The download directory path.
    """
    download_base = "/tmp/download"
    try:
        timestamp_dir = None

        # Method 1: system `date` (second resolution). Accept only a well-formed
        # stamp, and close the pipe explicitly.
        try:
            with os.popen("date +%Y%m%d_%H%M%S 2>/dev/null") as pipe:
                timestamp_str = pipe.read().strip()
            if re.fullmatch(r"\d{8}_\d{6}", timestamp_str):
                timestamp_dir = timestamp_str
        except Exception:
            pass

        if timestamp_dir is None:
            # Method 2: Python datetime (meaningful once the system clock is synced).
            try:
                from datetime import datetime
                timestamp_dir = datetime.now().strftime("%Y%m%d_%H%M%S")
            except Exception:
                # Method 3: fixed placeholder.
                timestamp_dir = "00000000_000000"

        download_dir = f"{download_base}/{timestamp_dir}"

        try:
            os.makedirs(download_dir, exist_ok=True)
        except Exception as e:
            self.logger.warning(f"[OTA] 创建下载目录失败: {e},使用基础目录")
            download_dir = download_base
            try:
                os.makedirs(download_dir, exist_ok=True)
            except Exception:
                pass

        return download_dir

    except Exception as e:
        self.logger.error(f"[OTA] 获取下载目录失败: {e},使用默认目录")
        return download_base
|
||
|
||
def get_filename_from_url(self, url, default_name="main_tmp"):
    """
    Build the local download path for *url* inside a fresh timestamped directory.

    The last path segment of the URL is percent-decoded and used as the file
    name, with its extension preserved. Decoding happens before taking the
    basename, so an encoded "%2F" cannot smuggle "../" segments out of the
    download directory. Empty names, "." and "..", and names containing NUL
    fall back to *default_name*. The query string is ignored.

    Args:
        url: download URL.
        default_name: file name used when none can be taken from the URL.

    Returns:
        Full file path, e.g. "/tmp/download/20250108_143025/main.zip".
    """
    try:
        download_dir = self.get_download_timestamp_dir()

        filename = os.path.basename(unquote(urlparse(url).path))

        if filename.strip() and filename not in (".", "..") and "\x00" not in filename:
            return f"{download_dir}/{filename}"
        return f"{download_dir}/{default_name}"

    except Exception as e:
        self.logger.error(f"[OTA] 从URL提取文件名失败: {e},使用默认文件名")
        download_dir = self.get_download_timestamp_dir()
        return f"{download_dir}/{default_name}"
|
||
|
||
def download_file_via_wifi(self, url, filename, timeout=(10, 120)):
    """
    Download *url* to *filename* over the system network stack.

    Archive and binary extensions (.zip, .enc, .bin, ...) are written byte for
    byte. Other files are decoded as UTF-8 and written in text mode. If the
    server sends a ``Content-Md5`` header, the file on disk is verified
    against it.

    Args:
        url: source URL.
        filename: destination path.
        timeout: passed to ``requests.get``, either seconds or a
            ``(connect, read)`` tuple. ``requests`` has no default timeout, so
            without one a stalled connection blocks forever. Pass None for the
            old unbounded behaviour.

    Returns:
        A status string: "下载成功!..." on success, "下载失败!..." on failure.
    """
    try:
        self.logger.info(f"正在从 {url} 下载文件...")
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()

        # MD5 from the response headers, if the server provides one.
        md5_b64_expected = None
        if 'Content-Md5' in response.headers:
            md5_b64_expected = response.headers['Content-Md5'].strip()
            self.logger.info(f"[DOWNLOAD] 服务器提供了MD5校验值: {md5_b64_expected}")

        filename_lower = filename.lower()
        is_binary = filename_lower.endswith(('.zip', '.zip.enc', '.enc', '.bin', '.tar', '.gz', '.exe', '.dll', '.so', '.dylib'))

        if is_binary:
            data = response.content
            with open(filename, 'wb') as file:
                file.write(data)
            # Flush to storage.
            try:
                os.sync()
            except Exception:
                pass
            self.logger.info(f"[DOWNLOAD] 使用二进制模式下载: {filename}, 大小: {len(data)} bytes")
        else:
            response.encoding = 'utf-8'
            with open(filename, 'w', encoding='utf-8') as file:
                file.write(response.text)
            self.logger.info(f"[DOWNLOAD] 使用文本模式下载: {filename}")

        if md5_b64_expected:
            try:
                with open(filename, "rb") as f:
                    digest = hashlib.md5(f.read()).digest()
                md5_b64_got = binascii.b2a_base64(digest).decode().strip()

                if md5_b64_got != md5_b64_expected:
                    self.logger.error(f"[DOWNLOAD] MD5校验失败: 期望={md5_b64_expected}, 实际={md5_b64_got}")
                    return f"下载失败!MD5校验失败: 期望={md5_b64_expected}, 实际={md5_b64_got}"
                self.logger.info(f"[DOWNLOAD] MD5校验通过: {md5_b64_got}")
            except Exception as e:
                self.logger.warning(f"[DOWNLOAD] MD5校验过程出错: {e}")
                # A verification error is fatal for binary packages (e.g. ZIP).
                if is_binary:
                    return f"下载失败!MD5校验异常: {e}"
        elif is_binary:
            self.logger.warning(f"[DOWNLOAD] 警告: 服务器未提供MD5校验值,无法验证文件完整性")

        return f"下载成功!文件已保存为: {filename}"

    except requests.exceptions.RequestException as e:
        return f"下载失败!网络请求错误: {e}"
    except OSError as e:
        return f"下载失败!文件写入错误: {e}"
    except Exception as e:
        return f"下载失败!发生未知错误: {e}"
|
||
|
||
# def direct_ota_download(self, ota_url):
|
||
# """直接执行 OTA 下载(假设已有网络)"""
|
||
|
||
# self._set_ota_url(ota_url)
|
||
# self._start_update_thread()
|
||
|
||
# try:
|
||
# if not ota_url:
|
||
# from network import safe_enqueue
|
||
# safe_enqueue({"result": "ota_failed", "reason": "missing_url"}, 2)
|
||
# return
|
||
|
||
# parsed_url = urlparse(ota_url)
|
||
# host = parsed_url.hostname
|
||
# port = parsed_url.port or (443 if parsed_url.scheme == 'https' else 80)
|
||
|
||
# if not network_manager.is_server_reachable(host, port, timeout=8):
|
||
# from network import safe_enqueue
|
||
# safe_enqueue({"result": "ota_failed", "reason": f"无法连接 {host}:{port}"}, 2)
|
||
# return
|
||
|
||
# downloaded_filename = self.get_filename_from_url(ota_url, default_name="main_tmp")
|
||
# self.logger.info(f"[OTA] 下载文件将保存为: {downloaded_filename}")
|
||
# self.logger.info(f"[OTA] 开始下载: {ota_url}")
|
||
# result_msg = self.download_file(ota_url, downloaded_filename)
|
||
# self.logger.info(f"[OTA] {result_msg}")
|
||
|
||
# if "成功" in result_msg or "下载成功" in result_msg:
|
||
# if self.apply_ota_and_reboot(ota_url, downloaded_filename):
|
||
# return
|
||
# else:
|
||
# from network import safe_enqueue
|
||
# safe_enqueue({"result": result_msg}, 2)
|
||
|
||
# except Exception as e:
|
||
# error_msg = f"OTA 异常: {str(e)}"
|
||
# self.logger.error(error_msg)
|
||
# from network import safe_enqueue
|
||
# safe_enqueue({"result": "ota_failed", "reason": error_msg}, 2)
|
||
# finally:
|
||
# self._stop_update_thread()
|
||
|
||
def download_file_via_4g(self, url, filename,
                         total_timeout_ms=600000,
                         retries=3,
                         debug=False):
    """
    Download *url* to *filename* through the ML307R module's AT HTTP client.

    Strategy (fixed small Range chunks fetched in order, derived from main109.py):
      - Relies only on the +MHTTPURC "header"/"content" events
        (no MHTTPREAD / cached mode).
      - Requests one small Range at a time (CHUNK_MAX bytes). A failed chunk is
        retried, halving the chunk size down to CHUNK_MIN.
      - Every chunk uses a fresh MHTTPCREATE/MHTTPREQUEST instance, so the
        module cannot get stuck in a "206 header but no content" state.
      - Only the bytes actually received are written, at the chunk's offset.
        A short final chunk therefore no longer pads the file with zeros.
      - If the server sent Content-Md5, the finished file is verified in
        streamed 64 KiB blocks.

    Args:
        url: source URL. The query string is forwarded. Only the known static
            host uses HTTPS; other hosts are fetched over plain HTTP.
        filename: destination path; truncated before the download starts.
        total_timeout_ms: overall time budget, checked between chunks.
        retries: unused (per-chunk retries are CHUNK_RETRIES); kept for API
            compatibility.
        debug: enable verbose chunk/power logging.

    Returns:
        ``(ok, message)``.
    """
    from urllib.parse import urlparse
    from hardware import hardware_manager

    # Small-chunk strategy (kept in sync with main109.py)
    CHUNK_MAX = 10240       # bytes per Range request when things go well
    CHUNK_MIN = 128         # lower bound for the chunk size after failures
    CHUNK_RETRIES = 12      # attempts per chunk before giving up
    FRAG_SIZE = 1024        # MHTTPCFG "fragment": URC payload size
    FRAG_DELAY = 10         # MHTTPCFG "fragment": delay between fragments
    STALL_TIMEOUT_MS = 60000
    # Errors that suggest the module's HTTP stack is wedged; repeated ones trigger a full reset.
    BAD_STATE_MARKERS = ("no_header_or_total", "MHTTPREQUEST failed", "MHTTPCREATE failed")

    t_func0 = time.ticks_ms()

    parsed = urlparse(url)
    host = parsed.hostname
    path = parsed.path or "/"
    if not host:
        return False, "bad_url (no host)"
    if parsed.query:
        # Keep the query string (e.g. signed download URLs); MHTTPREQUEST takes path+query.
        path = f"{path}?{parsed.query}"

    # MHTTP on many ML307R firmwares is unreliable over HTTPS. Only the known
    # static host uses HTTPS; every other host is fetched over plain HTTP.
    if isinstance(url, str) and url.startswith("https://static.shelingxingqiu.com/"):
        base_url = "https://static.shelingxingqiu.com"
        # TODO: HTTPS is experimental here, verify it works reliably
        self._is_https = True
    else:
        try:
            port = parsed.port
        except ValueError:
            return False, "bad_url (bad port)"
        # Keep an explicit port only for http:// URLs; an https port (e.g. 443)
        # would be wrong once the request is sent as plain HTTP.
        if port and parsed.scheme == "http":
            base_url = f"http://{host}:{port}"
        else:
            base_url = f"http://{host}"
        self._is_https = False

    def _log(*a):
        if debug:
            self.logger.debug(" ".join(str(x) for x in a))

    def _pwr_log(prefix=""):
        """Debug only: log bus voltage and battery percentage."""
        if not debug:
            return
        try:
            v = get_bus_voltage()
            p = voltage_to_percent(v)
            self.logger.debug(f"[PWR]{prefix} v={v:.3f}V p={p}%")
        except Exception as e:
            try:
                self.logger.debug(f"[PWR]{prefix} read_failed: {e}")
            except Exception:
                pass

    def _clear_http_events():
        if hardware_manager.at_client:
            while hardware_manager.at_client.pop_http_event() is not None:
                pass

    def _parse_httpid(resp: str):
        m = re.search(r"\+MHTTPCREATE:\s*(\d+)", resp)
        return int(m.group(1)) if m else None

    def _get_ip():
        r = hardware_manager.at_client.send("AT+CGPADDR=1", "OK", 3000)
        m = re.search(r'\+CGPADDR:\s*1,"([^"]+)"', r)
        return m.group(1) if m else ""

    def _ensure_pdp():
        ip = _get_ip()
        if ip and ip != "0.0.0.0":
            return True, ip
        hardware_manager.at_client.send("AT+MIPCALL=1,1", "OK", 15000)
        for _ in range(10):
            ip = _get_ip()
            if ip and ip != "0.0.0.0":
                return True, ip
            time.sleep(1)
        return False, ip

    def _extract_hdr_fields(hdr_text: str):
        mlen = re.search(r"Content-Length:\s*(\d+)", hdr_text, re.IGNORECASE)
        clen = int(mlen.group(1)) if mlen else None
        mmd5 = re.search(r"Content-Md5:\s*([A-Za-z0-9+/=]+)", hdr_text, re.IGNORECASE)
        md5_b64 = mmd5.group(1).strip() if mmd5 else None
        return clen, md5_b64

    def _extract_content_range(hdr_text: str):
        m = re.search(r"Content-Range:\s*bytes\s*(\d+)\s*-\s*(\d+)\s*/\s*(\d+)", hdr_text, re.IGNORECASE)
        if not m:
            return None, None, None
        try:
            return int(m.group(1)), int(m.group(2)), int(m.group(3))
        except ValueError:
            return None, None, None

    def _hard_reset_http():
        """Conservative cleanup when the module looks stuck: delete HTTP instances 0-5."""
        _clear_http_events()
        for i in range(0, 6):
            try:
                hardware_manager.at_client.send(f"AT+MHTTPDEL={i}", "OK", 1200)
            except Exception:
                pass
        _clear_http_events()

    def _create_httpid(full_reset=False):
        """Create an HTTP instance for base_url; returns (hid or None, last AT response)."""
        _clear_http_events()
        if hardware_manager.at_client:
            hardware_manager.at_client.flush()
        if full_reset:
            _hard_reset_http()
        resp = hardware_manager.at_client.send(f'AT+MHTTPCREATE="{base_url}"', "OK", 8000)
        hid = _parse_httpid(resp)
        if hid is None or not self._is_https:
            return hid, resp

        resp = hardware_manager.at_client.send(f'AT+MHTTPCFG="ssl",{hid},1,1', "OK", 2000)
        if "ERROR" in resp:
            self.logger.error(f"MHTTPCFG SSL failed: {resp}")
            # Release the failed HTTPS instance before falling back to plain HTTP,
            # otherwise every chunk leaks one of the module's few HTTP slots.
            try:
                hardware_manager.at_client.send(f"AT+MHTTPDEL={hid}", "OK", 2000)
            except Exception:
                pass
            downgraded_base_url = base_url.replace("https://", "http://")
            resp = hardware_manager.at_client.send(f'AT+MHTTPCREATE="{downgraded_base_url}"', "OK", 8000)
            hid = _parse_httpid(resp)

        return hid, resp

    def _fetch_range_into_buf(start, want_len, out_buf, full_reset=False):
        """
        Request bytes [start, start + want_len) and store them in out_buf
        (a bytearray of length want_len; it is never resized).

        Returns (ok, msg, total_len, md5_b64, got_len). ``got_len`` is the
        number of body bytes stored, and may be less than want_len at EOF.
        """
        end_incl = start + want_len - 1
        hid, cresp = _create_httpid(full_reset=full_reset)
        if hid is None:
            return False, f"MHTTPCREATE failed: {cresp}", None, None, 0

        # Reduce URC pressure (fragment size / delay).
        hardware_manager.at_client.send(f'AT+MHTTPCFG="fragment",{hid},{FRAG_SIZE},{FRAG_DELAY}', "OK", 1500)
        # Range header (both ends inclusive).
        hardware_manager.at_client.send(f'AT+MHTTPCFG="header",{hid},"Range: bytes={start}-{end_incl}"', "OK", 3000)

        req = hardware_manager.at_client.send(f'AT+MHTTPREQUEST={hid},1,0,"{path}"', "OK", 15000)
        if "ERROR" in req:
            try:
                hardware_manager.at_client.send(f"AT+MHTTPDEL={hid}", "OK", 2000)
            except Exception:
                pass
            return False, f"MHTTPREQUEST failed: {req}", None, None, 0

        # Wait for the header and content URCs.
        hdr_accum = ""
        code = None
        resp_total = None   # body length of this response
        total_len = None    # full file size from Content-Range
        md5_b64 = None
        got_ranges = set()
        last_sum = 0
        oversized = False
        logged_hdr = False
        t0 = time.ticks_ms()
        timeout_ms = 9000

        while time.ticks_diff(time.ticks_ms(), t0) < timeout_ms:
            ev = hardware_manager.at_client.pop_http_event() if hardware_manager.at_client else None
            if not ev:
                time.sleep_ms(5)
                continue

            if ev[0] == "header":
                _, ehid, ecode, ehdr = ev
                if ehid != hid:
                    continue
                code = ecode
                if ehdr:
                    # Headers may arrive split over several URCs; parse the accumulated text.
                    hdr_accum = (hdr_accum + "\n" + ehdr) if hdr_accum else ehdr

                resp_total_tmp, md5_tmp = _extract_hdr_fields(hdr_accum)
                if md5_tmp:
                    md5_b64 = md5_tmp
                cr_s, cr_e, cr_total = _extract_content_range(hdr_accum)
                if cr_total is not None:
                    total_len = cr_total
                if resp_total_tmp is not None:
                    resp_total = resp_total_tmp
                elif resp_total is None and cr_s is not None and cr_e is not None and cr_e >= cr_s:
                    resp_total = cr_e - cr_s + 1
                if not logged_hdr and (resp_total is not None or total_len is not None):
                    _log(f"[HDR] id={hid} code={code} clen={resp_total} cr={cr_s}-{cr_e}/{cr_total}")
                    logged_hdr = True
                continue

            if ev[0] == "content":
                _, ehid, _total, _sum, _cur, payload = ev
                if ehid != hid:
                    continue
                if resp_total is None:
                    resp_total = _total
                if resp_total is None or resp_total <= 0:
                    continue
                if resp_total > len(out_buf):
                    # More data than requested (e.g. the server ignored Range and sent
                    # the whole file). Storing it would silently grow out_buf.
                    oversized = True
                    break
                # _sum is the running byte count and _cur the size of this fragment.
                start_rel = _sum - _cur
                end_rel = min(_sum, resp_total)
                if start_rel < 0 or start_rel >= resp_total:
                    continue
                actual_len = min(len(payload), end_rel - start_rel)
                if actual_len <= 0:
                    continue
                out_buf[start_rel:start_rel + actual_len] = payload[:actual_len]
                got_ranges.add((start_rel, start_rel + actual_len))
                if _sum > last_sum:
                    last_sum = _sum
                if debug and (last_sum >= resp_total or (last_sum % 512 == 0)):
                    _log(f"[CHUNK] {start}+{last_sum}/{resp_total}")
                if last_sum >= resp_total:
                    break

        # Free the instance (fast path: only the current hid).
        try:
            hardware_manager.at_client.send(f"AT+MHTTPDEL={hid}", "OK", 2000)
        except Exception:
            pass

        if resp_total is None:
            return False, "no_header_or_total", total_len, md5_b64, 0
        if oversized:
            return False, f"unexpected_len resp={resp_total} want={want_len} code={code}", total_len, md5_b64, 0

        # Merge received intervals to measure how much of the body was filled.
        merged = []
        for s, e in sorted(got_ranges):
            if not merged or s > merged[-1][1]:
                merged.append((s, e))
            else:
                merged[-1] = (merged[-1][0], max(merged[-1][1], e))
        filled = sum(e - s for s, e in merged)

        if filled < resp_total:
            return False, f"incomplete_chunk got={filled} expected={resp_total} code={code}", total_len, md5_b64, filled

        return True, "OK", total_len, md5_b64, resp_total

    from network import network_manager

    self._begin_ota()
    try:
        with network_manager.get_uart_lock():
            ok_pdp, ip = _ensure_pdp()
            if not ok_pdp:
                return False, f"PDP not ready (ip={ip})"

            # Drop stale events first so they cannot be mistaken for ours.
            _clear_http_events()

            # Create/truncate the file so chunks can be written at arbitrary offsets.
            try:
                with open(filename, "wb") as f:
                    f.write(b"")
            except Exception as e:
                return False, f"open_file_failed: {e}"

            total_len = None
            expect_md5_b64 = None
            offset = 0
            chunk = CHUNK_MAX
            t_start = time.ticks_ms()
            last_progress_ms = t_start
            last_pwr_ms = t_start
            bad_http_state = 0
            _pwr_log(prefix=" ota_start")

            while True:
                now = time.ticks_ms()
                if debug and time.ticks_diff(now, last_pwr_ms) >= 5000:
                    last_pwr_ms = now
                    _pwr_log(prefix=f" off={offset}/{total_len or '?'}")
                if time.ticks_diff(now, t_start) > total_timeout_ms:
                    return False, f"timeout overall after {total_timeout_ms}ms offset={offset} total={total_len}"
                if time.ticks_diff(now, last_progress_ms) > STALL_TIMEOUT_MS:
                    return False, f"timeout stalled {STALL_TIMEOUT_MS}ms offset={offset} total={total_len}"

                want = chunk
                if total_len is not None:
                    remain = total_len - offset
                    if remain <= 0:
                        break
                    want = min(want, remain)

                buf = bytearray(want)
                success = False
                last_err = "unknown"
                got_len = 0

                for k in range(1, CHUNK_RETRIES + 1):
                    ok, msg, tlen, md5_b64, got = _fetch_range_into_buf(
                        offset, want, buf, full_reset=(bad_http_state >= 2))
                    last_err = msg
                    if tlen is not None and total_len is None:
                        total_len = tlen
                    if md5_b64 and not expect_md5_b64:
                        expect_md5_b64 = md5_b64
                    if ok and got > 0:
                        success = True
                        got_len = got
                        bad_http_state = 0
                        break
                    if ok:
                        # Empty body: no progress is possible, so count it as a failed attempt.
                        msg = last_err = "empty_chunk"

                    if any(marker in msg for marker in BAD_STATE_MARKERS):
                        bad_http_state += 1
                    else:
                        bad_http_state = max(0, bad_http_state - 1)

                    if chunk > CHUNK_MIN:
                        chunk = max(CHUNK_MIN, chunk // 2)
                    want = min(chunk, want)
                    if total_len is not None:
                        # total_len may have been learned during this attempt: never ask past EOF.
                        want = min(want, total_len - offset)
                        if want <= 0:
                            success = True
                            got_len = 0
                            break
                    buf = bytearray(want)
                    _log(f"[RETRY] off={offset} want={want} try={k} err={msg}")
                    _pwr_log(prefix=f" retry{k} off={offset}")
                    time.sleep_ms(120)

                if not success:
                    return False, f"chunk_failed off={offset} want={want} err={last_err} total={total_len}"

                # Write only the bytes actually received (binary mode).
                try:
                    with open(filename, "r+b") as f:
                        f.seek(offset)
                        f.write(memoryview(buf)[:got_len])
                except Exception as e:
                    return False, f"write_failed off={offset}: {e}"

                offset += got_len
                last_progress_ms = time.ticks_ms()
                if total_len is None and got_len < len(buf):
                    # No Content-Range was seen, but the body came back short: that is EOF.
                    total_len = offset
                chunk = CHUNK_MAX
                if debug:
                    _log(f"[OK] offset={offset}/{total_len or '?'}")

            # MD5 check, streamed so a multi-MB package need not fit in RAM.
            if expect_md5_b64:
                try:
                    md5 = hashlib.md5()
                    with open(filename, "rb") as f:
                        while True:
                            block = f.read(64 * 1024)
                            if not block:
                                break
                            md5.update(block)
                    got_b64 = binascii.b2a_base64(md5.digest()).decode().strip()
                    if got_b64 != expect_md5_b64:
                        return False, f"md5_mismatch got={got_b64} expected={expect_md5_b64}"
                    self.logger.debug(f"[4G-DL] MD5 verified: {got_b64}")
                except Exception as e:
                    return False, f"md5_check_failed: {e}"

            t_cost = time.ticks_diff(time.ticks_ms(), t_func0)
            self.logger.info(f"[4G-DL] download complete: size={offset} ip={ip} cost_ms={t_cost}")
            return True, f"OK size={offset} ip={ip} cost_ms={t_cost}"
    finally:
        self._end_ota()
|
||
|
||
def direct_ota_download_via_4g(self, ota_url):
    """Download an OTA package through the 4G module (no Wi-Fi needed) and apply it.

    Sequence:
      1. Record URL and mode, and mark the update thread as started.
      2. Pause TCP traffic for the whole OTA via ``_begin_ota``.
      3. Disconnect the server and close AT socket 0.
      4. Download with ``download_file_via_4g``.
      5. On success, hand the file to ``apply_ota_and_reboot``.

    Outcomes are reported through ``network.safe_enqueue``.

    Args:
        ota_url: URL of the OTA package. A falsy value is reported as
            ``{"result": "ota_failed", "reason": "missing_url"}``.

    Returns:
        None.
    """
    self._set_ota_url(ota_url)
    self._set_ota_mode("4g")
    self._start_update_thread()
    # Deferred import to avoid a circular dependency with the network module.
    from network import safe_enqueue

    t_ota0 = time.ticks_ms()
    # _end_ota() must only run if _begin_ota() actually ran. Previously the
    # missing-URL early return reached _end_ota() without a matching begin.
    ota_begun = False
    try:
        if not ota_url:
            safe_enqueue({"result": "ota_failed", "reason": "missing_url"}, 2)
            return

        # Pause TCP for the whole OTA so heartbeat / reconnect do not compete
        # for the 4G UART lock (that drops the server link and makes HTTP URCs
        # easier to lose).
        self._begin_ota()
        ota_begun = True

        # Actively close the AT TCP socket to reduce +MIPURC noise while the
        # HTTP URC download runs.
        from network import network_manager
        network_manager.disconnect_server()
        try:
            with network_manager.get_uart_lock():
                hardware_manager.at_client.send("AT+MIPCLOSE=0", "OK", 1500)
        except Exception as e:
            # Best effort: failing to close the socket must not abort the OTA.
            self.logger.debug(f"[OTA-4G] AT+MIPCLOSE=0 failed (ignored): {e}")

        # Derive the local file name from the URL (keeps the original extension).
        downloaded_filename = self.get_filename_from_url(ota_url, default_name="main_tmp")
        self.logger.info(f"[OTA-4G] 下载文件将保存为: {downloaded_filename}")
        self.logger.info(f"[OTA-4G] 开始通过 4G 下载: {ota_url}")

        # Why no dial-up here:
        # - AT+MDIALUP / RNDIS is the "USB host dial-up" mode. On many ML307R
        #   firmwares it takes over / switches the module's internal network
        #   stack, which breaks the AT+MIPOPEN / +MIPURC TCP connection
        #   (symptom: stuck at "connecting to server...").
        # - On this device 4G is used over UART + AT sockets (MIPOPEN). It is
        #   not exposed as a system network interface (e.g. ppp0).
        # So no automatic dial-up or routing changes are made here.
        # NOTE(review): an older note said a `requests` download is only tried
        # when a default route already exists (e.g. eth0). This method itself
        # always calls download_file_via_4g.

        # Best-effort battery log before the URC download starts.
        try:
            import power
            v = power.get_bus_voltage()
            p = power.voltage_to_percent(v)
            self.logger.info(f"[OTA-4G][PWR] before_urc v={v:.3f}V p={p}%")
        except Exception as e:
            self.logger.error(f"[OTA-4G][PWR] before_urc read_failed: {e}")

        t_dl0 = time.ticks_ms()
        success, msg = self.download_file_via_4g(ota_url, downloaded_filename, debug=False)
        # Elapsed time is computed as now - start by plain subtraction, so its
        # sign does not depend on time.ticks_diff's argument convention.
        # Assumes ticks_ms() does not wrap during an OTA; TODO confirm for maix.time.
        t_dl_cost = time.ticks_ms() - t_dl0
        self.logger.info(f"[OTA-4G] {msg}")
        self.logger.info(f"[OTA-4G] download_cost_ms={t_dl_cost}")

        if success and "OK" in msg:
            # apply_ota_and_reboot is expected to reboot on success. A False
            # return simply lets this thread finish.
            self.apply_ota_and_reboot(ota_url, downloaded_filename)
        else:
            safe_enqueue({"result": msg}, 2)

    except Exception as e:
        error_msg = f"OTA-4G 异常: {str(e)}"
        self.logger.error(error_msg)
        safe_enqueue({"result": "ota_failed", "reason": error_msg}, 2)
    finally:
        # Total duration. This may never print if the device already rebooted.
        try:
            t_cost = time.ticks_ms() - t_ota0
            self.logger.info(f"[OTA-4G] total_cost_ms={t_cost}")
        except Exception:
            pass  # logging must never prevent the cleanup below
        self._stop_update_thread()
        if ota_begun:
            self._end_ota()  # pairs with _begin_ota() above
|
||
|
||
def handle_wifi_and_update(self, ssid, password, ota_url):
    """Connect to Wi-Fi and run the OTA update; meant to run in a worker thread.

    Steps:
      1. Record URL/mode and mark the update thread as started.
      2. Pause the main loop / heartbeat via ``_begin_ota`` (same as the 4G path).
      3. Connect to ``ssid`` while verifying that the OTA host:port is reachable.
      4. Download the package and apply it.

    Progress and failures are reported through ``safe_enqueue``.
    ``_stop_update_thread`` and ``_end_ota`` always run on exit.

    Args:
        ssid: Wi-Fi network name.
        password: Wi-Fi password.
        ota_url: URL of the OTA package.
    """
    self._set_ota_url(ota_url)
    self._set_ota_mode("wifi")
    self._start_update_thread()
    # Imported lazily to avoid a circular import.
    from network import network_manager, safe_enqueue

    try:
        self._begin_ota()

        if not ota_url:
            safe_enqueue({"result": "ota_failed", "reason": "missing_url"}, 2)
            return

        from urllib.parse import urlparse
        target = urlparse(ota_url)
        host = target.hostname
        fallback_port = 443 if target.scheme == 'https' else 80
        port = target.port or fallback_port

        # The OTA host:port doubles as the reachability probe for the new
        # network. Per the original design note, connect_wifi persists
        # SSID/password to /boot/ only once connection + reachability succeed.
        ip, error = network_manager.connect_wifi(
            ssid, password, verify_host=host, verify_port=port, persist=True,
        )
        if error:
            safe_enqueue({"result": "wifi_failed", "error": error}, 2)
            return
        safe_enqueue({"result": "wifi_connected", "ip": ip}, 2)

        downloaded_filename = self.get_filename_from_url(ota_url, default_name="main_tmp")
        self.logger.info(f"[OTA] 下载文件将保存为: {downloaded_filename}")
        self.logger.info(f"[NET] 已确认可访问 {host}:{port},开始下载...")

        result = self.download_file_via_wifi(ota_url, downloaded_filename)
        self.logger.info(result)

        # download_file_via_wifi signals success through its message text.
        if not any(marker in result for marker in ("成功", "下载成功")):
            safe_enqueue({"result": result}, 2)
            return

        # Expected to reboot on success; its return value is not used further.
        self.apply_ota_and_reboot(ota_url, downloaded_filename)
    except Exception as e:
        err_msg = f"下载失败: {str(e)}"
        safe_enqueue({"result": err_msg}, 2)
        self.logger.error(err_msg)
    finally:
        self._stop_update_thread()
        self._end_ota()  # pairs with _begin_ota() at the top of the try
        print("[UPDATE] 更新线程执行完毕,即将退出。")
|
||
|
||
def restore_from_backup(self, backup_dir_path=None):
    """Restore every file of a backup directory into ``config.APP_DIR``.

    Args:
        backup_dir_path: Backup directory to restore from. When None, the
            ``backup_<N>`` sub-directory of ``config.BACKUP_BASE`` with the
            largest integer ``N`` is used.

    Returns:
        bool: True if at least one file was restored. False if the backup
        directory is missing, no numbered backup exists, nothing was
        restored, or an unexpected error occurred. Individual file failures
        (including failing to create the destination directory) are logged
        and skipped.
    """
    backup_base = config.BACKUP_BASE
    prefix = "backup_"

    try:
        if backup_dir_path is None:
            if not os.path.exists(backup_base):
                self.logger.error(f"[RESTORE] 备份目录不存在: {backup_base}")
                return False

            backup_dirs = []
            for item in os.listdir(backup_base):
                if item == ".counter":
                    continue
                item_path = os.path.join(backup_base, item)
                if not (os.path.isdir(item_path) and item.startswith(prefix)):
                    continue
                try:
                    # Strip only the leading prefix; str.replace would remove
                    # every occurrence of "backup_" in the name.
                    dir_num = int(item[len(prefix):])
                except ValueError:
                    continue  # not a numbered backup directory
                backup_dirs.append((item, dir_num))

            if not backup_dirs:
                self.logger.error(f"[RESTORE] 没有找到备份目录")
                return False

            # Highest number wins; ties resolve to the first listed, as before.
            latest_backup = max(backup_dirs, key=lambda entry: entry[1])[0]
            backup_dir_path = os.path.join(backup_base, latest_backup)

        if not os.path.exists(backup_dir_path):
            self.logger.error(f"[RESTORE] 备份目录不存在: {backup_dir_path}")
            return False

        self.logger.info(f"[RESTORE] 开始从备份恢复: {backup_dir_path}")

        restored_files = []
        for root, _dirs, files in os.walk(backup_dir_path):
            for f in files:
                source_path = os.path.join(root, f)
                rel_path = os.path.relpath(source_path, backup_dir_path)
                dest_path = os.path.join(config.APP_DIR, rel_path)
                try:
                    # Directory creation lives inside the per-file try. One
                    # uncreatable parent then skips this file instead of
                    # aborting the whole restore.
                    dest_dir = os.path.dirname(dest_path)
                    if dest_dir:
                        os.makedirs(dest_dir, exist_ok=True)
                    shutil.copy2(source_path, dest_path)
                    restored_files.append(rel_path)
                    self.logger.info(f"[RESTORE] 已恢复: {rel_path}")
                except Exception as e:
                    self.logger.error(f"[RESTORE] 恢复 {rel_path} 失败: {e}")

        if restored_files:
            self.logger.info(f"[RESTORE] 成功恢复 {len(restored_files)} 个文件")
            return True

        self.logger.info(f"[RESTORE] 没有文件被恢复")
        return False

    except Exception as e:
        self.logger.error(f"[RESTORE] 恢复过程出错: {e}")
        return False
|
||
|
||
|
||
# 创建全局单例实例
|
||
# Module-level singleton used by the backward-compatible wrapper functions
# below. OTAManager.__new__ returns the same instance on every construction.
ota_manager = OTAManager()
|
||
|
||
# ==================== 向后兼容的函数接口 ====================
|
||
# 这些函数会更新 ota_manager 的状态,并调用实际实现
|
||
|
||
def apply_ota_and_reboot(ota_url=None, downloaded_file=None):
    """Apply a downloaded OTA package and reboot (backward-compatible entry point).

    Delegates to ``ota_manager.apply_ota_and_reboot`` and returns its result.
    """
    delegate = ota_manager.apply_ota_and_reboot
    return delegate(ota_url, downloaded_file)
|
||
|
||
def direct_ota_download(ota_url):
    """Run an OTA download directly (backward-compatible entry point).

    Delegates to ``ota_manager.direct_ota_download`` and returns its result.
    """
    delegate = ota_manager.direct_ota_download
    return delegate(ota_url)
|
||
|
||
def direct_ota_download_via_4g(ota_url):
    """Run an OTA download over the 4G module (backward-compatible entry point).

    Delegates to ``ota_manager.direct_ota_download_via_4g`` and returns its result.
    """
    delegate = ota_manager.direct_ota_download_via_4g
    return delegate(ota_url)
|
||
|
||
def handle_wifi_and_update(ssid, password, ota_url):
    """Connect to Wi-Fi, then run the OTA update (backward-compatible entry point).

    Delegates to ``ota_manager.handle_wifi_and_update`` and returns its result.
    """
    delegate = ota_manager.handle_wifi_and_update
    return delegate(ssid, password, ota_url)
|
||
|
||
def restore_from_backup(backup_dir_path=None):
    """Restore application files from a backup (backward-compatible entry point).

    Delegates to ``ota_manager.restore_from_backup`` and returns its bool result.
    """
    delegate = ota_manager.restore_from_backup
    return delegate(backup_dir_path)
|
||
|