1+ import os
2+ import hashlib
3+ import socket
4+ import logging
5+ from typing import Optional , Callable
6+
7+
8+ logger = logging .getLogger (__name__ )
9+
10+
11+ def _file_sha256 (filename : str ):
12+ """Calculate SHA256 hash of a file."""
13+ sha256_hash = hashlib .sha256 ()
14+
15+ with open (filename , "rb" ) as f :
16+ for byte_block in iter (lambda : f .read (4096 ), b"" ):
17+ sha256_hash .update (byte_block )
18+
19+ return sha256_hash
20+
21+ class ESP32WiFiOTA :
22+ """ESP32 WiFi Unified OTA updates."""
23+
24+ def __init__ (self , filename : str , hostname : str , port : int = 3232 ):
25+ self ._filename = filename
26+ self ._hostname = hostname
27+ self ._port = port
28+ self ._socket : Optional [socket .socket ] = None
29+
30+ if not os .path .exists (self ._filename ):
31+ raise Exception (f"File { self ._filename } does not exist" )
32+
33+ self ._file_hash = _file_sha256 (self ._filename )
34+
35+ def _read_line (self ) -> str :
36+ """Read a line from the socket."""
37+ if not self ._socket :
38+ raise Exception ("Socket not connected" )
39+
40+ line = b""
41+ while not line .endswith (b"\n " ):
42+ char = self ._socket .recv (1 )
43+
44+ if not char :
45+ raise Exception ("Connection closed while waiting for response" )
46+
47+ line += char
48+
49+ return line .decode ("utf-8" ).strip ()
50+
51+ def hash_bytes (self ) -> bytes :
52+ """Return the hash as bytes."""
53+ return self ._file_hash .digest ()
54+
55+ def hash_hex (self ) -> str :
56+ """Return the hash as a hex string."""
57+ return self ._file_hash .hexdigest ()
58+
59+ def update (self , progress_callback : Optional [Callable [[int , int ], None ]] = None ):
60+ """Perform the OTA update."""
61+ with open (self ._filename , "rb" ) as f :
62+ data = f .read ()
63+ size = len (data )
64+
65+ logger .info (f"Starting OTA update with { self ._filename } ({ size } bytes, hash { self .hash_hex ()} )" )
66+
67+ self ._socket = socket .socket (socket .AF_INET , socket .SOCK_STREAM )
68+ self ._socket .settimeout (15 )
69+ try :
70+ self ._socket .connect ((self ._hostname , self ._port ))
71+ logger .debug (f"Connected to { self ._hostname } :{ self ._port } " )
72+
73+ # Send start command
74+ self ._socket .sendall (f"OTA { size } { self .hash_hex ()} \n " .encode ("utf-8" ))
75+
76+ # Wait for OK from the device
77+ while True :
78+ response = self ._read_line ()
79+ if response == "OK" :
80+ break
81+ elif response == "ERASING" :
82+ logger .info ("Device is erasing flash..." )
83+ elif response .startswith ("ERR " ):
84+ raise Exception (f"Device reported error: { response } " )
85+ else :
86+ logger .warning (f"Unexpected response: { response } " )
87+
88+ # Stream firmware
89+ sent_bytes = 0
90+ chunk_size = 1024
91+ while sent_bytes < size :
92+ chunk = data [sent_bytes : sent_bytes + chunk_size ]
93+ self ._socket .sendall (chunk )
94+ sent_bytes += len (chunk )
95+
96+ if progress_callback :
97+ progress_callback (sent_bytes , size )
98+ else :
99+ print (f"[{ sent_bytes / size * 100 :5.1f} %] Sent { sent_bytes } of { size } bytes..." , end = "\r " )
100+
101+ if not progress_callback :
102+ print ()
103+
104+ # Wait for OK from device
105+ logger .info ("Firmware sent, waiting for verification..." )
106+ while True :
107+ response = self ._read_line ()
108+
109+ if response == "OK" :
110+ logger .info ("OTA update completed successfully!" )
111+ break
112+ elif response == "ACK" :
113+ continue
114+ elif response .startswith ("ERR " ):
115+ raise Exception (f"OTA update failed: { response } " )
116+ else :
117+ logger .warning (f"Unexpected final response: { response } " )
118+
119+ finally :
120+ if self ._socket :
121+ self ._socket .close ()
122+ self ._socket = None
0 commit comments