Efficient recovery of the output projection matrix up to an orthogonal transform
Develop a computationally efficient algorithm that recovers the transformer's output embedding projection matrix W ∈ R^{l×h} up to an orthogonal h×h transformation, using only logit vectors returned by API queries. Concretely, project the logit outputs for prompts p_i onto the h-dimensional column space of W (with orthonormal basis U) to obtain points x_i = U^T W g_θ(p_i), which lie on an ellipsoid. Then efficiently solve the overdetermined system x_i^T A x_i = 1, which is linear in the entries of the symmetric positive definite matrix A ∈ R^{h×h}. Factor A = M^T M and reconstruct W = U M^{-1} O for some orthogonal O. The goal is to make this orthogonal-recovery attack practical: solved directly, the system has h(h+1)/2 = O(h^2) unknowns, which is currently infeasible at realistic hidden sizes.
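A minimal baseline sketch of the pipeline above in Python/NumPy follows. It assumes noise-free logits and that each hidden state g_θ(p_i) lies on the unit sphere in R^h (for example, a normalization output whose scale has been folded into W). W and the hidden states are simulated with random data, and the helper names `fit_ellipsoid` and `recover_projection` are hypothetical. The ellipsoid fit is the direct least-squares solve over h(h+1)/2 unknowns. That is the costly baseline this problem asks to improve, so the sketch is only practical for small h.

```python
import numpy as np


def fit_ellipsoid(X: np.ndarray) -> np.ndarray:
    """Least-squares fit of a symmetric A with x_i^T A x_i = 1 for each row x_i of X.

    The constraint is linear in the h(h+1)/2 upper-triangular entries of A.
    """
    n, h = X.shape
    rows, cols = np.triu_indices(h)
    # x^T A x = sum_j A_jj x_j^2 + 2 * sum_{j<k} A_jk x_j x_k
    weight = np.where(rows == cols, 1.0, 2.0)
    design = X[:, rows] * X[:, cols] * weight          # shape (n, h(h+1)/2)
    coef, *_ = np.linalg.lstsq(design, np.ones(n), rcond=None)
    A = np.zeros((h, h))
    A[rows, cols] = coef
    return A + np.triu(A, 1).T                          # mirror the strict upper part


def recover_projection(logits: np.ndarray, h: int) -> np.ndarray:
    """Recover W (l x h) up to a right orthogonal factor from logits (n x l)."""
    # U: orthonormal basis of the h-dimensional column space of W.
    U, _, _ = np.linalg.svd(logits.T, full_matrices=False)
    U = U[:, :h]
    X = logits @ U                                      # row i is x_i = U^T W g_i
    A = fit_ellipsoid(X)
    L = np.linalg.cholesky(A)                           # A = L L^T, so M = L^T
    M = L.T
    return U @ np.linalg.inv(M)                         # equals W O^T for some orthogonal O


# Hypothetical simulation: random W and hidden states on the unit sphere (assumption).
rng = np.random.default_rng(0)
l, h, n = 20, 4, 50
W = rng.normal(size=(l, h))
G = rng.normal(size=(n, h))
G /= np.linalg.norm(G, axis=1, keepdims=True)
logits = G @ W.T                                        # row i is W g_i
W_hat = recover_projection(logits, h)
O = np.linalg.pinv(W_hat) @ W                           # the leftover h x h factor
print(np.allclose(O.T @ O, np.eye(h)), np.allclose(W_hat @ O, W))
```

Cholesky is just one convenient choice of M: every factor with A = M^T M has the form Q M for an orthogonal Q, and that Q is absorbed into O.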
References
However, we do not know how to solve these systems of linear equations in $h^2$ variables efficiently ($h > 750$ in all our experiments), so in practice we resort to reconstructing weights up to an arbitrary $h \times h$ matrix, as described in Appendix \ref{sec:proof_of_42}. We do not carry out this attack in practice for the models considered in this paper, and leave improving this algorithm as an open problem for future work.
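The fallback mentioned above, recovering W only up to an arbitrary h × h matrix, can be sketched as follows. This is an assumed illustration, not necessarily the appendix's exact procedure. With noise-free logits, the top h left singular vectors of the stacked logit matrix, scaled by their singular values, give W̃ = W S for some unknown invertible h × h matrix S. The function name `recover_up_to_linear` and the random W and hidden states are hypothetical.

```python
import numpy as np


def recover_up_to_linear(logits: np.ndarray, h: int) -> np.ndarray:
    """Return an l x h matrix W_tilde with W_tilde = W S for some invertible h x h S."""
    # logits.T = W G^T has rank h, so its top-h left singular vectors span col(W).
    U, s, _ = np.linalg.svd(logits.T, full_matrices=False)
    return U[:, :h] * s[:h]                            # U Sigma restricted to the top h


# Hypothetical simulation with random W and hidden states (no sphere constraint needed).
rng = np.random.default_rng(1)
l, h, n = 20, 4, 50
W = rng.normal(size=(l, h))
G = rng.normal(size=(n, h))
logits = G @ W.T                                       # row i is W g_i
W_tilde = recover_up_to_linear(logits, h)
S = np.linalg.pinv(W) @ W_tilde                        # the unknown h x h factor
print(np.allclose(W @ S, W_tilde), np.linalg.matrix_rank(S) == h)
```

Unlike the ellipsoid fit, this needs only one SVD, which is why it scales to large h; the price is that S is an arbitrary invertible matrix rather than an orthogonal one.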