Skip to content

Refactor ota package #97

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions command/ota/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
package ota

import (
"bytes"
"errors"
"fmt"
"io/ioutil"
"os"

Expand Down Expand Up @@ -51,17 +51,16 @@ func Generate(binFile string, outFile string, fqbn string) error {
return err
}

var w bytes.Buffer
otaWriter := inota.NewWriter(&w, arduinoVendorID, productID)
_, err = otaWriter.Write(data)
out, err := os.Create(outFile)
if err != nil {
return err
}
otaWriter.Close()
defer out.Close()

err = ioutil.WriteFile(outFile, w.Bytes(), os.FileMode(0644))
enc := inota.NewEncoder(out, arduinoVendorID, productID)
err = enc.Encode(data)
if err != nil {
return err
return fmt.Errorf("failed to encode binary file: %w", err)
}

return nil
Expand Down
1 change: 1 addition & 0 deletions command/ota/massupload.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ func run(uploader otaUploader, ids []string, otaFile string, expiration int) []R
results = append(results, r)
continue
}
defer file.Close()
jobs <- job{id: id, file: file}
}
close(jobs)
Expand Down
1 change: 1 addition & 0 deletions command/ota/upload.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ func Upload(params *UploadParams, cred *config.Credentials) error {
if err != nil {
return fmt.Errorf("%s: %w", "cannot open ota file", err)
}
defer file.Close()

expiration := otaExpirationMins
if params.Deferred {
Expand Down
102 changes: 36 additions & 66 deletions internal/ota/encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,60 +18,48 @@
package ota

import (
"bufio"
"encoding/binary"
"fmt"
"hash/crc32"
"io"
"strconv"

"github.com/arduino/arduino-cloud-cli/internal/lzss"
"github.com/juju/errors"
)

// A writer is a buffered, flushable writer.
type writer interface {
io.Writer
Flush() error
}

// encoder encodes a binary into an .ota file.
type encoder struct {
// w is the writer that compressed bytes are written to.
w writer
// Encoder writes a binary to an output stream in the ota format.
type Encoder struct {
// w is the stream where encoded bytes are written.
w io.Writer

// vendorID is the ID of the board vendor
// vendorID is the ID of the board vendor.
vendorID string

// is the ID of the board vendor is the ID of the board model
// productID is the ID of the board model.
productID string
}

// NewWriter creates a new `WriteCloser` for the the given VID/PID.
func NewWriter(w io.Writer, vendorID, productID string) io.WriteCloser {
bw, ok := w.(writer)
if !ok {
bw = bufio.NewWriter(w)
}
return &encoder{
w: bw,
// NewEncoder creates a new ota encoder.
func NewEncoder(w io.Writer, vendorID, productID string) *Encoder {
return &Encoder{
w: w,
vendorID: vendorID,
productID: productID,
}
}

// Write writes a compressed representation of p to e's underlying writer.
func (e *encoder) Write(binaryData []byte) (int, error) {
//log.Println("original binaryData is", len(binaryData), "bytes length")

// Magic number (VID/PID)
// Encode compresses data using a lzss algorithm, encodes the result
// in ota format and writes it to e's underlying writer.
func (e *Encoder) Encode(data []byte) error {
// Compute the magic number (VID/PID)
magicNumber := make([]byte, 4)
vid, err := strconv.ParseUint(e.vendorID, 16, 16)
if err != nil {
return 0, errors.Annotate(err, "OTA encoder: failed to parse vendorID")
return fmt.Errorf("cannot parse vendorID: %w", err)
}
pid, err := strconv.ParseUint(e.productID, 16, 16)
if err != nil {
return 0, errors.Annotate(err, "OTA encoder: failed to parse productID")
return fmt.Errorf("cannot parse productID: %w", err)
}

binary.LittleEndian.PutUint16(magicNumber[0:2], uint16(pid))
Expand All @@ -82,61 +70,43 @@ func (e *encoder) Write(binaryData []byte) (int, error) {
Compression: true,
}

// Compress the compiled binary
compressed := lzss.Encode(binaryData)

compressed := lzss.Encode(data)
// Prepend magic number and version field to payload
var binDataComplete []byte
binDataComplete = append(binDataComplete, magicNumber...)
binDataComplete = append(binDataComplete, version.AsBytes()...)
binDataComplete = append(binDataComplete, compressed...)
//log.Println("binDataComplete is", len(binDataComplete), "bytes length")
var outData []byte
outData = append(outData, magicNumber...)
outData = append(outData, version.Bytes()...)
outData = append(outData, compressed...)

headerSize, err := e.writeHeader(binDataComplete)
err = e.writeHeader(outData)
if err != nil {
return headerSize, err
return fmt.Errorf("cannot write data header to output stream: %w", err)
}

payloadSize, err := e.writePayload(binDataComplete)
_, err = e.w.Write(outData)
if err != nil {
return payloadSize, err
return fmt.Errorf("cannot write encoded data to output stream: %w", err)
}

return headerSize + payloadSize, nil
}

// Close closes the encoder, flushing any pending output. It does not close or
// flush e's underlying writer.
func (e *encoder) Close() error {
return e.w.Flush()
return nil
}

func (e *encoder) writeHeader(binDataComplete []byte) (int, error) {

func (e *Encoder) writeHeader(data []byte) error {
// Write the length of the content
lengthAsBytes := make([]byte, 4)
binary.LittleEndian.PutUint32(lengthAsBytes, uint32(len(binDataComplete)))

n, err := e.w.Write(lengthAsBytes)
binary.LittleEndian.PutUint32(lengthAsBytes, uint32(len(data)))
_, err := e.w.Write(lengthAsBytes)
if err != nil {
return n, err
return err
}

// Calculate the checksum for binDataComplete
crc := crc32.ChecksumIEEE(binDataComplete)

// encode the checksum uint32 value as 4 bytes
// Write the checksum uint32 value as 4 bytes
crc := crc32.ChecksumIEEE(data)
crcAsBytes := make([]byte, 4)
binary.LittleEndian.PutUint32(crcAsBytes, crc)

n, err = e.w.Write(crcAsBytes)
_, err = e.w.Write(crcAsBytes)
if err != nil {
return n, err
return err
}

return len(lengthAsBytes) + len(crcAsBytes), nil
}

func (e *encoder) writePayload(data []byte) (int, error) {
return e.w.Write(data)
return nil
}
65 changes: 54 additions & 11 deletions internal/ota/encoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package ota
import (
"bytes"
"encoding/hex"
"log"
"io/ioutil"

"fmt"
"hash/crc32"
Expand All @@ -37,30 +37,25 @@ func TestComputeCrc32Checksum(t *testing.T) {
assert.Equal(t, crc, uint32(2090640218))
}

func TestEncoderWrite(t *testing.T) {

func TestEncode(t *testing.T) {
// Setup test data
data, _ := hex.DecodeString("DEADBEEF") // uncompressed, or 'ef 6b 77 de f0' (compressed w/ LZSS)

var w bytes.Buffer
vendorID := "2341" // Arduino
productID := "8054" // MRK Wifi 1010

otaWriter := NewWriter(&w, vendorID, productID)
defer otaWriter.Close()
enc := NewEncoder(&w, vendorID, productID)

n, err := otaWriter.Write(data)
err := enc.Encode(data)
if err != nil {
t.Error(err)
t.Fail()
}
log.Println("written ota of", n, "bytes length")

otaWriter.Close()
actual := w.Bytes()

// You can get the expected result creating an `.ota` file using Alex's tools:
// https://github.com/arduino-libraries/ArduinoIoTCloud/tree/master/extras/tools
// Expected result has been computed with the following tool:
// https://github.com/arduino-libraries/ArduinoIoTCloud/tree/master/extras/tools .
expected, _ := hex.DecodeString("11000000a1744bd4548041230000000000000040ef6b77def0")

res := bytes.Compare(expected, actual)
Expand All @@ -72,3 +67,51 @@ func TestEncoderWrite(t *testing.T) {

assert.Assert(t, res == 0) // 0 means equal
}

// Expected '.ota' files contained in testdata have been computed with the following tool:
// https://github.com/arduino-libraries/ArduinoIoTCloud/tree/master/extras/tools .
func TestEncodeFiles(t *testing.T) {
tests := []struct {
name string
infile string
outfile string
}{
{
name: "blink",
infile: "testdata/blink.bin",
outfile: "testdata/blink.ota",
},
{
name: "cloud sketch",
infile: "testdata/cloud.bin",
outfile: "testdata/cloud.ota",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
input, err := ioutil.ReadFile(tt.infile)
if err != nil {
t.Fatal("couldn't open test file")
}

want, err := ioutil.ReadFile(tt.outfile)
if err != nil {
t.Fatal("couldn't open test file")
}

var got bytes.Buffer
vendorID := "2341" // Arduino
productID := "8057" // Nano 33 IoT
otaenc := NewEncoder(&got, vendorID, productID)
err = otaenc.Encode(input)
if err != nil {
t.Error(err)
}

if !bytes.Equal(want, got.Bytes()) {
t.Error("encoding failed")
}
})
}
}
Binary file added internal/ota/testdata/blink.bin
Binary file not shown.
Binary file added internal/ota/testdata/blink.ota
Binary file not shown.
Binary file added internal/ota/testdata/cloud.bin
Binary file not shown.
Binary file added internal/ota/testdata/cloud.ota
Binary file not shown.
Binary file removed internal/ota/testdata/lorem.lzss
Binary file not shown.
9 changes: 0 additions & 9 deletions internal/ota/testdata/lorem.txt

This file was deleted.

4 changes: 2 additions & 2 deletions internal/ota/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ type Version struct {
PayloadBuildNum uint32
}

// AsBytes builds a 8 byte length representation of the Version Struct for the OTA update.
func (v *Version) AsBytes() []byte {
// Bytes builds a 8 byte length representation of the Version Struct for the OTA update.
func (v *Version) Bytes() []byte {
version := []byte{0, 0, 0, 0, 0, 0, 0, 0}

// Set compression
Expand Down
2 changes: 1 addition & 1 deletion internal/ota/version_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func TestVersionWithCompressionEnabled(t *testing.T) {
}

expected := []byte{0, 0, 0, 0, 0, 0, 0, 0x40}
actual := version.AsBytes()
actual := version.Bytes()

// create a tabwriter for formatting the output
w := new(tabwriter.Writer)
Expand Down