aboutsummaryrefslogtreecommitdiff
path: root/ports/stm32/mboot/mboot.py
blob: 39ae0f6f2db910b623ce395f5351d95f01fb0265 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
# Driver for Mboot, the MicroPython boot loader
# MIT license; Copyright (c) 2018 Damien P. George

import struct, time, os, hashlib


# Command codes of the mboot I2C protocol. They are numbered consecutively
# starting at 1, so they are bound by unpacking a range; the trailing comment
# on each name gives its resulting value.
(
    I2C_CMD_ECHO,       # 1
    I2C_CMD_GETID,      # 2
    I2C_CMD_GETCAPS,    # 3
    I2C_CMD_RESET,      # 4
    I2C_CMD_CONFIG,     # 5
    I2C_CMD_GETLAYOUT,  # 6
    I2C_CMD_MASSERASE,  # 7
    I2C_CMD_PAGEERASE,  # 8
    I2C_CMD_SETRDADDR,  # 9
    I2C_CMD_SETWRADDR,  # 10
    I2C_CMD_READ,       # 11
    I2C_CMD_WRITE,      # 12
    I2C_CMD_COPY,       # 13
    I2C_CMD_CALCHASH,   # 14
    I2C_CMD_MARKVALID,  # 15
) = range(1, 16)


class Bootloader:
    """Host-side driver for an mboot device on an I2C bus.

    Every command is one ``I2C_CMD_*`` byte, optionally followed by
    little-endian packed arguments, written to the device; the reply is then
    collected with ``wait_response``.
    """

    def __init__(self, i2c, addr):
        """Attach to the mboot device at I2C address *addr* on bus *i2c*.

        A zero-length write is used to probe for the device.

        Raises:
            Exception: if the probe write raises OSError.
        """
        self.i2c = i2c
        self.addr = addr
        # Scratch buffer for the 1-byte reply header read by wait_response().
        self.buf1 = bytearray(1)
        try:
            self.i2c.writeto(addr, b'')
        except OSError:
            raise Exception('no I2C mboot device found')

    def wait_response(self, timeout_ms=5000):
        """Poll the device for a reply and return its payload as bytes.

        The device first supplies a 1-byte header holding the payload length
        (0 means an empty reply); values >= 129 signal an error. While the
        header read raises OSError it is retried every 500 us until
        *timeout_ms* milliseconds have elapsed.

        Raises:
            Exception('timeout'): no header was read within *timeout_ms*.
            Exception(n): the header byte *n* was >= 129.
        """
        start = time.ticks_ms()
        while True:
            try:
                self.i2c.readfrom_into(self.addr, self.buf1)
                n = self.buf1[0]
                break
            except OSError:
                # Device not ready yet; back off briefly and retry.
                time.sleep_us(500)
            if time.ticks_diff(time.ticks_ms(), start) > timeout_ms:
                raise Exception('timeout')
        if n >= 129:
            # NOTE(review): presumably a negative error code sent as an
            # unsigned byte (n - 256) -- confirm against the mboot firmware.
            raise Exception(n)
        if n == 0:
            return b''
        return self.i2c.readfrom(self.addr, n)

    def wait_empty_response(self):
        """Wait for a reply and require that it carries no payload.

        Raises:
            Exception: if the reply is non-empty, or anything raised by
                ``wait_response``.
        """
        ret = self.wait_response()
        if ret:
            raise Exception('expected empty response got %r' % ret)
        return None

    def echo(self, data):
        """Send *data* with the ECHO command and return the device's reply."""
        self.i2c.writeto(self.addr, struct.pack('<B', I2C_CMD_ECHO) + data)
        return self.wait_response()

    def getid(self):
        """Return ``(unique_id, mcu_name, board_name)`` reported by the device.

        The reply is 12 bytes of unique ID followed by the MCU name and the
        board name, separated by a single NUL byte; both names are decoded
        as ASCII.
        """
        self.i2c.writeto(self.addr, struct.pack('<B', I2C_CMD_GETID))
        ret = self.wait_response()
        unique_id = ret[:12]
        mcu_name, board_name = ret[12:].split(b'\x00')
        return unique_id, str(mcu_name, 'ascii'), str(board_name, 'ascii')

    def reset(self):
        """Send the RESET command; no reply is read."""
        self.i2c.writeto(self.addr, struct.pack('<B', I2C_CMD_RESET))
        # we don't expect any response

    def getlayout(self):
        """Query the device's flash layout.

        Returns:
            A list of ``(address, size_in_bytes)`` tuples, one per page, in
            ascending address order.
        """
        self.i2c.writeto(self.addr, struct.pack('<B', I2C_CMD_GETLAYOUT))
        return self._parse_layout(self.wait_response())

    @staticmethod
    def _parse_layout(layout):
        """Parse a layout descriptor into a list of ``(address, size)`` pages.

        The descriptor has the form ``<name>/<hex base>/<count>*<size>Kg,...``,
        for example ``b'@Internal Flash  /0x08000000/04*016Kg,01*064Kg'``.
        Pages are laid out contiguously from the base address.
        """
        name, flash_addr, layout = layout.split(b'/')
        assert name == b'@Internal Flash  '
        flash_addr = int(flash_addr, 16)
        pages = []
        for chunk in layout.split(b','):
            n, sz = chunk.split(b'*')
            n = int(n)
            assert sz.endswith(b'Kg')
            sz = int(sz[:-2]) * 1024
            for _ in range(n):
                pages.append((flash_addr, sz))
                flash_addr += sz
        return pages

    def pageerase(self, addr):
        """Send PAGEERASE for *addr* (packed as uint32) and wait for the ack."""
        self.i2c.writeto(self.addr, struct.pack('<BI', I2C_CMD_PAGEERASE, addr))
        self.wait_empty_response()

    def setrdaddr(self, addr):
        """Set the device's read address to *addr* and wait for the ack."""
        self.i2c.writeto(self.addr, struct.pack('<BI', I2C_CMD_SETRDADDR, addr))
        self.wait_empty_response()

    def setwraddr(self, addr):
        """Set the device's write address to *addr* and wait for the ack."""
        self.i2c.writeto(self.addr, struct.pack('<BI', I2C_CMD_SETWRADDR, addr))
        self.wait_empty_response()

    def read(self, n):
        """Request *n* bytes (packed as one unsigned byte) and return the reply."""
        self.i2c.writeto(self.addr, struct.pack('<BB', I2C_CMD_READ, n))
        return self.wait_response()

    def write(self, buf):
        """Send *buf* with the WRITE command and wait for the ack."""
        self.i2c.writeto(self.addr, struct.pack('<B', I2C_CMD_WRITE) + buf)
        self.wait_empty_response()

    def calchash(self, n):
        """Ask the device to hash *n* bytes from its read address; return the digest."""
        self.i2c.writeto(self.addr, struct.pack('<BI', I2C_CMD_CALCHASH, n))
        return self.wait_response()

    def markvalid(self):
        """Send MARKVALID and wait for the ack."""
        self.i2c.writeto(self.addr, struct.pack('<B', I2C_CMD_MARKVALID))
        self.wait_empty_response()

    @staticmethod
    def _rate_kib_per_sec(nbytes, elapsed_ms):
        """Return the rate in KiB/s for *nbytes* moved in *elapsed_ms* ms.

        An elapsed time below 1 ms is treated as 1 ms to avoid dividing by 0.
        """
        return nbytes * 1000 / 1024 / max(elapsed_ms, 1)

    def deployfile(self, filename, addr):
        """Write *filename* to device flash at *addr*, verify it, then reset.

        Pages are erased the first time the write address falls inside them.
        After writing, the device hashes the written range and the result is
        compared with a locally computed SHA256. The firmware is marked valid
        only when the two match. The device is reset in either case.

        Raises:
            Exception: if the file is empty, or if the write address falls
                outside every page reported by ``getlayout``.
        """
        pages = self.getlayout()
        page_erased = [False] * len(pages)
        buf = bytearray(128) # maximum payload supported by I2C protocol
        start_addr = addr
        self.setwraddr(addr)
        fsize = os.stat(filename)[6]
        if fsize == 0:
            # Nothing to deploy; fail before any page is erased.
            raise Exception('file %s is empty' % filename)
        local_sha = hashlib.sha256()
        ntotal = 0
        print('Deploying %s to location 0x%08x' % (filename, addr))
        with open(filename, 'rb') as f:
            t0 = time.ticks_ms()
            while True:
                n = f.readinto(buf)
                if n == 0:
                    break

                # check if we need to erase the page
                for i, p in enumerate(pages):
                    if p[0] <= addr < p[0] + p[1]:
                        # found page
                        if not page_erased[i]:
                            print('\r% 3u%% erase 0x%08x' % (100 * (addr - start_addr) // fsize, addr), end='')
                            self.pageerase(addr)
                            page_erased[i] = True
                        break
                else:
                    raise Exception('address 0x%08x not valid' % addr)

                # A short (final) chunk leaves bytes from the previous chunk in
                # buf[n:]. Overwrite them with 0xff instead of programming stale
                # data past the end of the image. The full buffer is still sent,
                # exactly as before.
                # NOTE(review): the device's write granularity is not visible
                # here, which is why the write length is left unchanged.
                for j in range(n, len(buf)):
                    buf[j] = 0xff

                # write the data
                self.write(buf)

                # Update the local SHA256 over the n real bytes. On the first
                # chunk, set the low 2 bits of byte 0 ("validity bits") so the
                # local hash matches the device's copy.
                # NOTE(review): presumes the device sets those bits on the
                # first write -- confirm against the firmware.
                if addr == start_addr:
                    buf[0] |= 3
                if n == len(buf):
                    local_sha.update(buf)
                else:
                    local_sha.update(buf[:n])

                addr += n
                ntotal = addr - start_addr
                if ntotal % 2048 == 0 or ntotal == fsize:
                    print('\r% 3u%% % 7u bytes   ' % (100 * ntotal // fsize, ntotal), end='')
            t1 = time.ticks_ms()
        print()
        # ticks_diff handles wraparound of the millisecond tick counter.
        print('rate: %.2f KiB/sec' % self._rate_kib_per_sec(ntotal, time.ticks_diff(t1, t0)))

        local_sha = local_sha.digest()
        print('Local SHA256: ', ''.join('%02x' % x for x in local_sha))

        self.setrdaddr(start_addr)
        remote_sha = self.calchash(ntotal)
        print('Remote SHA256:', ''.join('%02x' % x for x in remote_sha))

        if local_sha == remote_sha:
            print('Marking app firmware as valid')
            self.markvalid()
        else:
            print('SHA256 mismatch, app firmware not marked as valid')

        self.reset()