initialise
[debian/goprotobuf.git] / proto / encode.go
1 // Go support for Protocol Buffers - Google's data interchange format
2 //
3 // Copyright 2010 Google Inc.  All rights reserved.
4 // http://code.google.com/p/goprotobuf/
5 //
6 // Redistribution and use in source and binary forms, with or without
7 // modification, are permitted provided that the following conditions are
8 // met:
9 //
10 //     * Redistributions of source code must retain the above copyright
11 // notice, this list of conditions and the following disclaimer.
12 //     * Redistributions in binary form must reproduce the above
13 // copyright notice, this list of conditions and the following disclaimer
14 // in the documentation and/or other materials provided with the
15 // distribution.
16 //     * Neither the name of Google Inc. nor the names of its
17 // contributors may be used to endorse or promote products derived from
18 // this software without specific prior written permission.
19 //
20 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
32 package proto
33
34 /*
35  * Routines for encoding data into the wire format for protocol buffers.
36  */
37
38 import (
39         "os"
40         "reflect"
41         "runtime"
42         "unsafe"
43 )
44
45 // ErrRequiredNotSet is the error returned if Marshal is called with
46 // a protocol buffer struct whose required fields have not
47 // all been initialized. It is also the error returned if Unmarshal is
48 // called with an encoded protocol buffer that does not include all the
49 // required fields.
50 type ErrRequiredNotSet struct {
51         t reflect.Type
52 }
53
54 func (e *ErrRequiredNotSet) String() string {
55         return "proto: required fields not set in " + e.t.String()
56 }
57
58 var (
59         // ErrRepeatedHasNil is the error returned if Marshal is called with
60         // a struct with a repeated field containing a nil element.
61         ErrRepeatedHasNil = os.NewError("proto: repeated field has nil")
62
63         // ErrNil is the error returned if Marshal is called with nil.
64         ErrNil = os.NewError("proto: Marshal called with nil")
65
66         // ErrNotPtr is the error returned if Marshal is called with a non-pointer.
67         ErrNotPtr = os.NewError("proto: Marshal called with a non-pointer")
68 )
69
70 // The fundamental encoders that put bytes on the wire.
71 // Those that take integer types all accept uint64 and are
72 // therefore of type valueEncoder.
73
74 const maxVarintBytes = 10 // maximum length of a varint
75
76 // EncodeVarint returns the varint encoding of x.
77 // This is the format for the
78 // int32, int64, uint32, uint64, bool, and enum
79 // protocol buffer types.
80 // Not used by the package itself, but helpful to clients
81 // wishing to use the same encoding.
82 func EncodeVarint(x uint64) []byte {
83         var buf [maxVarintBytes]byte
84         var n int
85         for n = 0; x > 127; n++ {
86                 buf[n] = 0x80 | uint8(x&0x7F)
87                 x >>= 7
88         }
89         buf[n] = uint8(x)
90         n++
91         return buf[0:n]
92 }
93
94 // EncodeVarint writes a varint-encoded integer to the Buffer.
95 // This is the format for the
96 // int32, int64, uint32, uint64, bool, and enum
97 // protocol buffer types.
98 func (p *Buffer) EncodeVarint(x uint64) os.Error {
99         for x >= 1<<7 {
100                 p.buf = append(p.buf, uint8(x&0x7f|0x80))
101                 x >>= 7
102         }
103         p.buf = append(p.buf, uint8(x))
104         return nil
105 }
106
107 // EncodeFixed64 writes a 64-bit integer to the Buffer.
108 // This is the format for the
109 // fixed64, sfixed64, and double protocol buffer types.
110 func (p *Buffer) EncodeFixed64(x uint64) os.Error {
111         p.buf = append(p.buf,
112                 uint8(x),
113                 uint8(x>>8),
114                 uint8(x>>16),
115                 uint8(x>>24),
116                 uint8(x>>32),
117                 uint8(x>>40),
118                 uint8(x>>48),
119                 uint8(x>>56))
120         return nil
121 }
122
123 // EncodeFixed32 writes a 32-bit integer to the Buffer.
124 // This is the format for the
125 // fixed32, sfixed32, and float protocol buffer types.
126 func (p *Buffer) EncodeFixed32(x uint64) os.Error {
127         p.buf = append(p.buf,
128                 uint8(x),
129                 uint8(x>>8),
130                 uint8(x>>16),
131                 uint8(x>>24))
132         return nil
133 }
134
135 // EncodeZigzag64 writes a zigzag-encoded 64-bit integer
136 // to the Buffer.
137 // This is the format used for the sint64 protocol buffer type.
138 func (p *Buffer) EncodeZigzag64(x uint64) os.Error {
139         // use signed number to get arithmetic right shift.
140         return p.EncodeVarint(uint64((x << 1) ^ uint64((int64(x) >> 63))))
141 }
142
143 // EncodeZigzag32 writes a zigzag-encoded 32-bit integer
144 // to the Buffer.
145 // This is the format used for the sint32 protocol buffer type.
146 func (p *Buffer) EncodeZigzag32(x uint64) os.Error {
147         // use signed number to get arithmetic right shift.
148         return p.EncodeVarint(uint64((uint32(x) << 1) ^ uint32((int32(x) >> 31))))
149 }
150
151 // EncodeRawBytes writes a count-delimited byte buffer to the Buffer.
152 // This is the format used for the bytes protocol buffer
153 // type and for embedded messages.
154 func (p *Buffer) EncodeRawBytes(b []byte) os.Error {
155         lb := len(b)
156         p.EncodeVarint(uint64(lb))
157         p.buf = append(p.buf, b...)
158         return nil
159 }
160
161 // EncodeStringBytes writes an encoded string to the Buffer.
162 // This is the format used for the proto2 string type.
163 func (p *Buffer) EncodeStringBytes(s string) os.Error {
164
165         // this works because strings and slices are the same.
166         y := *(*[]byte)(unsafe.Pointer(&s))
167         p.EncodeRawBytes(y)
168         return nil
169 }
170
171 // Marshaler is the interface representing objects that can marshal themselves.
172 type Marshaler interface {
173         Marshal() ([]byte, os.Error)
174 }
175
176 // Marshal takes the protocol buffer struct represented by pb
177 // and encodes it into the wire format, returning the data.
178 func Marshal(pb interface{}) ([]byte, os.Error) {
179         // Can the object marshal itself?
180         if m, ok := pb.(Marshaler); ok {
181                 return m.Marshal()
182         }
183         p := NewBuffer(nil)
184         err := p.Marshal(pb)
185         if err != nil {
186                 return nil, err
187         }
188         return p.buf, err
189 }
190
191 // Marshal takes the protocol buffer struct represented by pb
192 // and encodes it into the wire format, writing the result to the
193 // Buffer.
194 func (p *Buffer) Marshal(pb interface{}) os.Error {
195         // Can the object marshal itself?
196         if m, ok := pb.(Marshaler); ok {
197                 data, err := m.Marshal()
198                 if err != nil {
199                         return err
200                 }
201                 p.buf = append(p.buf, data...)
202                 return nil
203         }
204
205         mstat := runtime.MemStats.Mallocs
206
207         t, b, err := getbase(pb)
208         if t.Kind() != reflect.Ptr {
209                 return ErrNotPtr
210         }
211         if err == nil {
212                 err = p.enc_struct(t.Elem(), b)
213         }
214
215         mstat = runtime.MemStats.Mallocs - mstat
216         stats.Emalloc += mstat
217         stats.Encode++
218
219         return err
220 }
221
222 // Individual type encoders.
223
224 // Encode a bool.
225 func (o *Buffer) enc_bool(p *Properties, base uintptr) os.Error {
226         v := *(**uint8)(unsafe.Pointer(base + p.offset))
227         if v == nil {
228                 return ErrNil
229         }
230         x := *v
231         if x != 0 {
232                 x = 1
233         }
234         o.buf = append(o.buf, p.tagcode...)
235         p.valEnc(o, uint64(x))
236         return nil
237 }
238
239 // Encode an int32.
240 func (o *Buffer) enc_int32(p *Properties, base uintptr) os.Error {
241         v := *(**uint32)(unsafe.Pointer(base + p.offset))
242         if v == nil {
243                 return ErrNil
244         }
245         x := *v
246         o.buf = append(o.buf, p.tagcode...)
247         p.valEnc(o, uint64(x))
248         return nil
249 }
250
251 // Encode an int64.
252 func (o *Buffer) enc_int64(p *Properties, base uintptr) os.Error {
253         v := *(**uint64)(unsafe.Pointer(base + p.offset))
254         if v == nil {
255                 return ErrNil
256         }
257         x := *v
258         o.buf = append(o.buf, p.tagcode...)
259         p.valEnc(o, uint64(x))
260         return nil
261 }
262
263 // Encode a string.
264 func (o *Buffer) enc_string(p *Properties, base uintptr) os.Error {
265         v := *(**string)(unsafe.Pointer(base + p.offset))
266         if v == nil {
267                 return ErrNil
268         }
269         x := *v
270         o.buf = append(o.buf, p.tagcode...)
271         o.EncodeStringBytes(x)
272         return nil
273 }
274
275 // All protocol buffer fields are nillable, but be careful.
276 func isNil(v reflect.Value) bool {
277         switch v.Kind() {
278         case reflect.Map, reflect.Ptr, reflect.Slice:
279                 return v.IsNil()
280         }
281         return false
282 }
283
284 // Encode a message struct.
285 func (o *Buffer) enc_struct_message(p *Properties, base uintptr) os.Error {
286         // Can the object marshal itself?
287         iv := unsafe.Unreflect(p.stype, unsafe.Pointer(base+p.offset))
288         if m, ok := iv.(Marshaler); ok {
289                 if isNil(reflect.ValueOf(iv)) {
290                         return ErrNil
291                 }
292                 data, err := m.Marshal()
293                 if err != nil {
294                         return err
295                 }
296                 o.buf = append(o.buf, p.tagcode...)
297                 o.EncodeRawBytes(data)
298                 return nil
299         }
300         v := *(**struct{})(unsafe.Pointer(base + p.offset))
301         if v == nil {
302                 return ErrNil
303         }
304
305         // need the length before we can write out the message itself,
306         // so marshal into a separate byte buffer first.
307         obuf := o.buf
308         o.buf = o.bufalloc()
309
310         b := uintptr(unsafe.Pointer(v))
311         typ := p.stype.Elem()
312         err := o.enc_struct(typ, b)
313
314         nbuf := o.buf
315         o.buf = obuf
316         if err != nil {
317                 o.buffree(nbuf)
318                 return err
319         }
320         o.buf = append(o.buf, p.tagcode...)
321         o.EncodeRawBytes(nbuf)
322         o.buffree(nbuf)
323         return nil
324 }
325
326 // Encode a group struct.
327 func (o *Buffer) enc_struct_group(p *Properties, base uintptr) os.Error {
328         v := *(**struct{})(unsafe.Pointer(base + p.offset))
329         if v == nil {
330                 return ErrNil
331         }
332
333         o.EncodeVarint(uint64((p.Tag << 3) | WireStartGroup))
334         b := uintptr(unsafe.Pointer(v))
335         typ := p.stype.Elem()
336         err := o.enc_struct(typ, b)
337         if err != nil {
338                 return err
339         }
340         o.EncodeVarint(uint64((p.Tag << 3) | WireEndGroup))
341         return nil
342 }
343
344 // Encode a slice of bools ([]bool).
345 func (o *Buffer) enc_slice_bool(p *Properties, base uintptr) os.Error {
346         s := *(*[]uint8)(unsafe.Pointer(base + p.offset))
347         l := len(s)
348         if l == 0 {
349                 return ErrNil
350         }
351         for _, x := range s {
352                 o.buf = append(o.buf, p.tagcode...)
353                 if x != 0 {
354                         x = 1
355                 }
356                 p.valEnc(o, uint64(x))
357         }
358         return nil
359 }
360
361 // Encode a slice of bools ([]bool) in packed format.
362 func (o *Buffer) enc_slice_packed_bool(p *Properties, base uintptr) os.Error {
363         s := *(*[]uint8)(unsafe.Pointer(base + p.offset))
364         l := len(s)
365         if l == 0 {
366                 return ErrNil
367         }
368         o.buf = append(o.buf, p.tagcode...)
369         o.EncodeVarint(uint64(l)) // each bool takes exactly one byte
370         for _, x := range s {
371                 if x != 0 {
372                         x = 1
373                 }
374                 p.valEnc(o, uint64(x))
375         }
376         return nil
377 }
378
379 // Encode a slice of bytes ([]byte).
380 func (o *Buffer) enc_slice_byte(p *Properties, base uintptr) os.Error {
381         s := *(*[]uint8)(unsafe.Pointer(base + p.offset))
382         if s == nil {
383                 return ErrNil
384         }
385         o.buf = append(o.buf, p.tagcode...)
386         o.EncodeRawBytes(s)
387         return nil
388 }
389
390 // Encode a slice of int32s ([]int32).
391 func (o *Buffer) enc_slice_int32(p *Properties, base uintptr) os.Error {
392         s := *(*[]uint32)(unsafe.Pointer(base + p.offset))
393         l := len(s)
394         if l == 0 {
395                 return ErrNil
396         }
397         for i := 0; i < l; i++ {
398                 o.buf = append(o.buf, p.tagcode...)
399                 x := s[i]
400                 p.valEnc(o, uint64(x))
401         }
402         return nil
403 }
404
405 // Encode a slice of int32s ([]int32) in packed format.
406 func (o *Buffer) enc_slice_packed_int32(p *Properties, base uintptr) os.Error {
407         s := *(*[]uint32)(unsafe.Pointer(base + p.offset))
408         l := len(s)
409         if l == 0 {
410                 return ErrNil
411         }
412         // TODO: Reuse a Buffer.
413         buf := NewBuffer(nil)
414         for i := 0; i < l; i++ {
415                 p.valEnc(buf, uint64(s[i]))
416         }
417
418         o.buf = append(o.buf, p.tagcode...)
419         o.EncodeVarint(uint64(len(buf.buf)))
420         o.buf = append(o.buf, buf.buf...)
421         return nil
422 }
423
424 // Encode a slice of int64s ([]int64).
425 func (o *Buffer) enc_slice_int64(p *Properties, base uintptr) os.Error {
426         s := *(*[]uint64)(unsafe.Pointer(base + p.offset))
427         l := len(s)
428         if l == 0 {
429                 return ErrNil
430         }
431         for i := 0; i < l; i++ {
432                 o.buf = append(o.buf, p.tagcode...)
433                 x := s[i]
434                 p.valEnc(o, uint64(x))
435         }
436         return nil
437 }
438
439 // Encode a slice of int64s ([]int64) in packed format.
440 func (o *Buffer) enc_slice_packed_int64(p *Properties, base uintptr) os.Error {
441         s := *(*[]uint64)(unsafe.Pointer(base + p.offset))
442         l := len(s)
443         if l == 0 {
444                 return ErrNil
445         }
446         // TODO: Reuse a Buffer.
447         buf := NewBuffer(nil)
448         for i := 0; i < l; i++ {
449                 p.valEnc(buf, s[i])
450         }
451
452         o.buf = append(o.buf, p.tagcode...)
453         o.EncodeVarint(uint64(len(buf.buf)))
454         o.buf = append(o.buf, buf.buf...)
455         return nil
456 }
457
458 // Encode a slice of slice of bytes ([][]byte).
459 func (o *Buffer) enc_slice_slice_byte(p *Properties, base uintptr) os.Error {
460         ss := *(*[][]uint8)(unsafe.Pointer(base + p.offset))
461         l := len(ss)
462         if l == 0 {
463                 return ErrNil
464         }
465         for i := 0; i < l; i++ {
466                 o.buf = append(o.buf, p.tagcode...)
467                 s := ss[i]
468                 o.EncodeRawBytes(s)
469         }
470         return nil
471 }
472
473 // Encode a slice of strings ([]string).
474 func (o *Buffer) enc_slice_string(p *Properties, base uintptr) os.Error {
475         ss := *(*[]string)(unsafe.Pointer(base + p.offset))
476         l := len(ss)
477         for i := 0; i < l; i++ {
478                 o.buf = append(o.buf, p.tagcode...)
479                 s := ss[i]
480                 o.EncodeStringBytes(s)
481         }
482         return nil
483 }
484
485 // Encode a slice of message structs ([]*struct).
486 func (o *Buffer) enc_slice_struct_message(p *Properties, base uintptr) os.Error {
487         s := *(*[]*struct{})(unsafe.Pointer(base + p.offset))
488         l := len(s)
489         typ := p.stype.Elem()
490
491         for i := 0; i < l; i++ {
492                 v := s[i]
493                 if v == nil {
494                         return ErrRepeatedHasNil
495                 }
496
497                 // Can the object marshal itself?
498                 iv := unsafe.Unreflect(p.stype, unsafe.Pointer(&s[i]))
499                 if m, ok := iv.(Marshaler); ok {
500                         if isNil(reflect.ValueOf(iv)) {
501                                 return ErrNil
502                         }
503                         data, err := m.Marshal()
504                         if err != nil {
505                                 return err
506                         }
507                         o.buf = append(o.buf, p.tagcode...)
508                         o.EncodeRawBytes(data)
509                         continue
510                 }
511
512                 obuf := o.buf
513                 o.buf = o.bufalloc()
514
515                 b := uintptr(unsafe.Pointer(v))
516                 err := o.enc_struct(typ, b)
517
518                 nbuf := o.buf
519                 o.buf = obuf
520                 if err != nil {
521                         o.buffree(nbuf)
522                         if err == ErrNil {
523                                 return ErrRepeatedHasNil
524                         }
525                         return err
526                 }
527                 o.buf = append(o.buf, p.tagcode...)
528                 o.EncodeRawBytes(nbuf)
529
530                 o.buffree(nbuf)
531         }
532         return nil
533 }
534
535 // Encode a slice of group structs ([]*struct).
536 func (o *Buffer) enc_slice_struct_group(p *Properties, base uintptr) os.Error {
537         s := *(*[]*struct{})(unsafe.Pointer(base + p.offset))
538         l := len(s)
539         typ := p.stype.Elem()
540
541         for i := 0; i < l; i++ {
542                 v := s[i]
543                 if v == nil {
544                         return ErrRepeatedHasNil
545                 }
546
547                 o.EncodeVarint(uint64((p.Tag << 3) | WireStartGroup))
548
549                 b := uintptr(unsafe.Pointer(v))
550                 err := o.enc_struct(typ, b)
551
552                 if err != nil {
553                         if err == ErrNil {
554                                 return ErrRepeatedHasNil
555                         }
556                         return err
557                 }
558
559                 o.EncodeVarint(uint64((p.Tag << 3) | WireEndGroup))
560         }
561         return nil
562 }
563
564 // Encode an extension map.
565 func (o *Buffer) enc_map(p *Properties, base uintptr) os.Error {
566         v := *(*map[int32]Extension)(unsafe.Pointer(base + p.offset))
567         if err := encodeExtensionMap(v); err != nil {
568                 return err
569         }
570         for _, e := range v {
571                 o.buf = append(o.buf, e.enc...)
572         }
573         return nil
574 }
575
576 // Encode a struct.
577 func (o *Buffer) enc_struct(t reflect.Type, base uintptr) os.Error {
578         prop := GetProperties(t)
579         required := prop.reqCount
580         for _, p := range prop.Prop {
581                 if p.enc != nil {
582                         err := p.enc(o, p, base)
583                         if err != nil {
584                                 if err != ErrNil {
585                                         return err
586                                 }
587                         } else if p.Required {
588                                 required--
589                         }
590                 }
591         }
592         // See if we encoded all required fields.
593         if required > 0 {
594                 return &ErrRequiredNotSet{t}
595         }
596
597         return nil
598 }