aboutsummaryrefslogtreecommitdiffstats
path: root/OpenKeychain/src/main/java/org/sufficientlysecure/keychain/javacard/UsbTransport.java
blob: 07697f11ecb6578d75e91497d5282d10b57d39c5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
package org.sufficientlysecure.keychain.javacard;

import android.hardware.usb.UsbConstants;
import android.hardware.usb.UsbDevice;
import android.hardware.usb.UsbDeviceConnection;
import android.hardware.usb.UsbEndpoint;
import android.hardware.usb.UsbInterface;
import android.hardware.usb.UsbManager;
import android.support.annotation.NonNull;
import android.support.annotation.Nullable;
import android.util.Pair;

import org.bouncycastle.util.Arrays;
import org.bouncycastle.util.encoders.Hex;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class UsbTransport implements Transport {
    private static final int CLASS_SMARTCARD = 11;
    private static final int TIMEOUT = 20 * 1000; // 2 s

    private final UsbManager mUsbManager;
    private final UsbDevice mUsbDevice;
    private final UsbInterface mUsbInterface;
    private final UsbEndpoint mBulkIn;
    private final UsbEndpoint mBulkOut;
    private final UsbDeviceConnection mConnection;
    private byte mCounter = 0;

    public UsbTransport(final UsbDevice usbDevice, final UsbManager usbManager) throws TransportIoException {
        mUsbDevice = usbDevice;
        mUsbManager = usbManager;

        mUsbInterface = getSmartCardInterface(mUsbDevice);
        // throw if mUsbInterface == null
        final Pair<UsbEndpoint, UsbEndpoint> ioEndpoints = getIoEndpoints(mUsbInterface);
        mBulkIn = ioEndpoints.first;
        mBulkOut = ioEndpoints.second;
        // throw if any endpoint is null

        mConnection = mUsbManager.openDevice(mUsbDevice);
        // throw if connection is null
        mConnection.claimInterface(mUsbInterface, true);
        // check result

        powerOn();

        setTimings();
    }

    private void setTimings()  throws TransportIoException {
        byte[] data = {
                0x6C,
                0x00, 0x00, 0x00, 0x00,
                0x00,
                mCounter++,
                0x00, 0x00, 0x00
        };
        sendRaw(data);
        data = receive();

        data[0] = 0x61;
        data[1] = 0x04;
        data[2] = data[3] = data[4] = 0x00;
        data[5] = 0x00;
        data[6] = mCounter++;
        data[7] = 0x00;
        data[8] = data[9] = 0x00;

        data[13] = 1;

        sendRaw(data);
        receive();
    }

    private void powerOff() throws TransportIoException {
        final byte[] iccPowerOff = {
                0x63,
                0x00, 0x00, 0x00, 0x00,
                0x00,
                mCounter++,
                0x00,
                0x00, 0x00
        };
        sendRaw(iccPowerOff);
        receive();
    }

    void powerOn() throws TransportIoException {
        final byte[] iccPowerOn = {
                0x62,
                0x00, 0x00, 0x00, 0x00,
                0x00,
                mCounter++,
                0x00,
                0x00, 0x00
        };
        sendRaw(iccPowerOn);
        receive();
    }

    /**
     * Get first class 11 (Chip/Smartcard) interface for the device
     *
     * @param device {@link UsbDevice} which will be searched
     * @return {@link UsbInterface} of smartcard or null if it doesn't exist
     */
    @Nullable
    private static UsbInterface getSmartCardInterface(final UsbDevice device) {
        for (int i = 0; i < device.getInterfaceCount(); i++) {
            final UsbInterface anInterface = device.getInterface(i);
            if (anInterface.getInterfaceClass() == CLASS_SMARTCARD) {
                return anInterface;
            }
        }
        return null;
    }

    @NonNull
    private static Pair<UsbEndpoint, UsbEndpoint> getIoEndpoints(final UsbInterface usbInterface) {
        UsbEndpoint bulkIn = null, bulkOut = null;
        for (int i = 0; i < usbInterface.getEndpointCount(); i++) {
            final UsbEndpoint endpoint = usbInterface.getEndpoint(i);
            if (endpoint.getType() != UsbConstants.USB_ENDPOINT_XFER_BULK) {
                continue;
            }

            if (endpoint.getDirection() == UsbConstants.USB_DIR_IN) {
                bulkIn = endpoint;
            } else if (endpoint.getDirection() == UsbConstants.USB_DIR_OUT) {
                bulkOut = endpoint;
            }
        }
        return new Pair<>(bulkIn, bulkOut);
    }

    @Override
    public void release() {
        mConnection.releaseInterface(mUsbInterface);
        mConnection.close();
    }

    @Override
    public boolean isConnected() {
        // TODO: redo
        return mUsbManager.getDeviceList().containsValue(mUsbDevice);
    }

    @Override
    public byte[] sendAndReceive(byte[] data) throws TransportIoException {
        send(data);
        byte[] bytes;
        do {
            bytes = receive();
        } while (isXfrBlockNotReady(bytes));

        checkXfrBlockResult(bytes);
        return Arrays.copyOfRange(bytes, 10, bytes.length);
    }

    public void send(byte[] d) throws TransportIoException {
        int l = d.length;
        byte[] data = Arrays.concatenate(new byte[]{
                        0x6f,
                        (byte) l, (byte) (l >> 8), (byte) (l >> 16), (byte) (l >> 24),
                        0x00,
                        mCounter++,
                        0x00,
                        0x00, 0x00},
                d);

        int send = 0;
        while (send < data.length) {
            final int len = Math.min(mBulkIn.getMaxPacketSize(), data.length - send);
            sendRaw(Arrays.copyOfRange(data, send, send + len));
            send += len;
        }
    }

    public byte[] receive() throws TransportIoException {
        byte[] buffer = new byte[mBulkIn.getMaxPacketSize()];
        byte[] result = null;
        int readBytes = 0, totalBytes = 0;

        do {
            int res = mConnection.bulkTransfer(mBulkIn, buffer, buffer.length, TIMEOUT);
            if (res < 0) {
                throw new TransportIoException("USB error, failed to receive response " + res);
            }
            if (result == null) {
                if (res < 10) {
                    throw new TransportIoException("USB error, failed to receive ccid header");
                }
                totalBytes = ByteBuffer.wrap(buffer, 1, 4).order(ByteOrder.LITTLE_ENDIAN).asIntBuffer().get() + 10;
                result = new byte[totalBytes];
            }
            System.arraycopy(buffer, 0, result, readBytes, res);
            readBytes += res;
        } while (readBytes < totalBytes);

        return result;
    }

    private void sendRaw(final byte[] data) throws TransportIoException {
        final int tr1 = mConnection.bulkTransfer(mBulkOut, data, data.length, TIMEOUT);
        if (tr1 != data.length) {
            throw new TransportIoException("USB error, failed to send data " + tr1);
        }
    }

    private byte getStatus(byte[] bytes) {
        return (byte) ((bytes[7] >> 6) & 0x03);
    }

    private void checkXfrBlockResult(byte[] bytes) throws TransportIoException {
        final byte status = getStatus(bytes);
        if (status != 0) {
            throw new TransportIoException("CCID error, status " + status + " error code: " + Hex.toHexString(bytes, 8, 1));
        }
    }

    private boolean isXfrBlockNotReady(byte[] bytes) {
        return getStatus(bytes) == 2;
    }
}